Add filtered row-level result support for PatternMatch
eycho-am committed Feb 20, 2024
1 parent 0a55107 commit ad7a8cc
Showing 6 changed files with 111 additions and 19 deletions.
20 changes: 18 additions & 2 deletions src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala
@@ -19,6 +19,8 @@ package com.amazon.deequ.analyzers
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isString}
import com.google.common.annotations.VisibleForTesting
+import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.functions.not
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, lit, regexp_extract, sum, when}
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructType}
@@ -36,13 +38,14 @@ import scala.util.matching.Regex
* @param pattern The regular expression to check for
* @param where Additional filter to apply before the analyzer is run.
*/
-case class PatternMatch(column: String, pattern: Regex, where: Option[String] = None)
+case class PatternMatch(column: String, pattern: Regex, where: Option[String] = None,
+                        analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[NumMatchesAndCount]("PatternMatch", column)
with FilterableAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
ifNoNullsIn(result, offset, howMany = 2) { _ =>
-      NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion.cast(BooleanType)))
+      NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults.cast(BooleanType)))
}
}

@@ -82,6 +85,19 @@ case class PatternMatch(column: String, pattern: Regex, where: Option[String] =
conditionalSelection(expression, where).cast(IntegerType)
}

+  private[deequ] def rowLevelResults: Column = {
+    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
+    val whereNotCondition = where.map { expression => not(expr(expression)) }
+    val expression = when(regexp_extract(col(column), pattern.toString(), 0) =!= lit(""), 1)
+      .otherwise(0)
+
+    filteredRowOutcome match {
+      case FilteredRowOutcome.TRUE =>
+        conditionSelectionGivenColumn(expression, whereNotCondition, replaceWith = 1).cast(IntegerType)
+      case _ =>
+        criterion
+    }
+  }

}

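How the new column behaves: `rowLevelResults` only diverges from the existing `criterion` when `FilteredRowOutcome.TRUE` is configured, in which case rows excluded by the `where` filter are force-marked as passing; for any other outcome it falls through to `criterion`, whose `where`-guarded `when` leaves filtered rows null. A minimal sketch of the equivalent Spark expression, assuming the `att2`/`item` columns from the tests below and approximating the repo's `conditionSelectionGivenColumn` helper with a plain `when`/`otherwise`:

```scala
import org.apache.spark.sql.functions.{col, expr, lit, not, regexp_extract, when}

// Existing per-row criterion: 1 when the pattern matches, 0 otherwise.
val matches = when(regexp_extract(col("att2"), "(^f)", 0) =!= lit(""), 1).otherwise(0)

// FilteredRowOutcome.TRUE with where = "item < 4": rows failing the filter
// (NOT(item < 4)) are replaced with 1, so they surface as "true" in
// row-level results instead of as spurious failures.
val whereNot = not(expr("item < 4"))
val rowLevel = when(whereNot, lit(1)).otherwise(matches).cast("int")
```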
6 changes: 4 additions & 2 deletions src/main/scala/com/amazon/deequ/checks/Check.scala
@@ -839,18 +839,20 @@ case class Check(
* @param pattern The columns values will be checked for a match against this pattern.
* @param assertion Function that receives a double input parameter and returns a boolean
* @param hint A hint to provide additional context why a constraint could have failed
+   * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow)
* @return
*/
def hasPattern(
column: String,
pattern: Regex,
assertion: Double => Boolean = Check.IsOne,
name: Option[String] = None,
-      hint: Option[String] = None)
+      hint: Option[String] = None,
+      analyzerOptions: Option[AnalyzerOptions] = None)
: CheckWithLastConstraintFilterable = {

addFilterableConstraint { filter =>
-      Constraint.patternMatchConstraint(column, pattern, assertion, filter, name, hint)
+      Constraint.patternMatchConstraint(column, pattern, assertion, filter, name, hint, analyzerOptions)
}
}

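With the new parameter, a check can opt filtered rows into a `TRUE` or `NULL` row-level outcome. A usage sketch drawn from the tests in this commit (`data` and `session` are assumed to be an input DataFrame and SparkSession):

```scala
import com.amazon.deequ.{VerificationResult, VerificationSuite}
import com.amazon.deequ.analyzers.{AnalyzerOptions, FilteredRowOutcome}
import com.amazon.deequ.checks.{Check, CheckLevel}

val patternCheck = new Check(CheckLevel.Error, "rule6")
  .hasPattern("att2", """(^f)""".r,
    analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
  .where("item < 4")

val result = new VerificationSuite().onData(data).addCheck(patternCheck).run()

// Rows excluded by "item < 4" now read as null (FilteredRowOutcome.NULL)
// or true (FilteredRowOutcome.TRUE) in the row-level results.
val rowLevel = VerificationResult.rowLevelResultsAsDataFrame(session, result, data)
```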
6 changes: 4 additions & 2 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -373,17 +373,19 @@ object Constraint {
* @param pattern The regex pattern to check compliance for
* @param column Data frame column which is a combination of expression and the column name
* @param hint A hint to provide additional context why a constraint could have failed
+   * @param analyzerOptions Options to configure analyzer behavior (NullTreatment, FilteredRow)
*/
def patternMatchConstraint(
column: String,
pattern: Regex,
assertion: Double => Boolean,
where: Option[String] = None,
name: Option[String] = None,
-      hint: Option[String] = None)
+      hint: Option[String] = None,
+      analyzerOptions: Option[AnalyzerOptions] = None)
: Constraint = {

-    val patternMatch = PatternMatch(column, pattern, where)
+    val patternMatch = PatternMatch(column, pattern, where, analyzerOptions)

fromAnalyzer(patternMatch, pattern, assertion, name, hint)
}
24 changes: 22 additions & 2 deletions src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -321,12 +321,16 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
val max = new Check(CheckLevel.Error, "rule5")
.hasMax("item", _ < 4, None)
.where("item < 4")
val patternMatch = new Check(CheckLevel.Error, "rule6")
.hasPattern("att2", """(^f)""".r)
.where("item < 4")

val expectedColumn1 = completeness.description
val expectedColumn2 = uniqueness.description
val expectedColumn3 = uniquenessWhere.description
val expectedColumn4 = min.description
val expectedColumn5 = max.description
+    val expectedColumn6 = patternMatch.description


val suite = new VerificationSuite().onData(data)
@@ -335,6 +339,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
.addCheck(uniquenessWhere)
.addCheck(min)
.addCheck(max)
+      .addCheck(patternMatch)

val result: VerificationResult = suite.run()

@@ -343,7 +348,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item")
resultData.show(false)
val expectedColumns: Set[String] =
-      data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + expectedColumn4 + expectedColumn5
+      data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 +
+        expectedColumn4 + expectedColumn5 + expectedColumn6
assert(resultData.columns.toSet == expectedColumns)

// filtered rows 2,5 (where att1 = "a")
@@ -364,6 +370,10 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
// filtered rows 4, 5, 6 (where item < 4)
val maxRowLevel = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0))
assert(Seq(true, true, true, true, true, true).sameElements(maxRowLevel))

+      // filtered rows 4, 5, 6 (where item < 4)
+      val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0))
+      assert(Seq(true, false, false, true, true, true).sameElements(rowLevel6))
}

"generate a result that contains row-level results with null for filtered rows" in withSparkSession { session =>
@@ -385,19 +395,24 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
val max = new Check(CheckLevel.Error, "rule5")
.hasMax("item", _ < 4, None, analyzerOptions)
.where("item < 4")
val patternMatch = new Check(CheckLevel.Error, "rule6")
.hasPattern("att2", """(^f)""".r, analyzerOptions = analyzerOptions)
.where("item < 4")

val expectedColumn1 = completeness.description
val expectedColumn2 = uniqueness.description
val expectedColumn3 = uniquenessWhere.description
val expectedColumn4 = min.description
val expectedColumn5 = max.description
+    val expectedColumn6 = patternMatch.description

val suite = new VerificationSuite().onData(data)
.addCheck(completeness)
.addCheck(uniqueness)
.addCheck(uniquenessWhere)
.addCheck(min)
.addCheck(max)
+      .addCheck(patternMatch)

val result: VerificationResult = suite.run()

@@ -406,7 +421,8 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
val resultData = VerificationResult.rowLevelResultsAsDataFrame(session, result, data).orderBy("item")
resultData.show(false)
val expectedColumns: Set[String] =
-      data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 + expectedColumn4 + expectedColumn5
+      data.columns.toSet + expectedColumn1 + expectedColumn2 + expectedColumn3 +
+        expectedColumn4 + expectedColumn5 + expectedColumn6
assert(resultData.columns.toSet == expectedColumns)

val rowLevel1 = resultData.select(expectedColumn1).collect().map(r => r.getAs[Any](0))
@@ -426,6 +442,10 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
// filtered rows 4, 5, 6 (where item < 4)
val rowLevel5 = resultData.select(expectedColumn5).collect().map(r => r.getAs[Any](0))
assert(Seq(true, true, true, null, null, null).sameElements(rowLevel5))

+      // filtered rows 4, 5, 6 (where item < 4)
+      val rowLevel6 = resultData.select(expectedColumn6).collect().map(r => r.getAs[Any](0))
+      assert(Seq(true, false, false, null, null, null).sameElements(rowLevel6))
}

"generate a result that contains compliance row-level results " in withSparkSession { session =>
56 changes: 54 additions & 2 deletions src/test/scala/com/amazon/deequ/analyzers/PatternMatchTest.scala
@@ -33,10 +33,35 @@ class PatternMatchTest extends AnyWordSpec with Matchers with SparkContextSpec w
val state = patternMatchCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe
data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe
Seq(true, true, true, true, true, true, true, true)
}

"return row-level results for non-null columns starts with digit" in withSparkSession { session =>

val data = getDfWithStringColumns(session)

val patternMatchCountry = PatternMatch("Address Line 1", """(^[0-4])""".r)
val state = patternMatchCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe
Seq(false, false, true, true, false, false, true, true)
}

"return row-level results for non-null columns starts with digit filtered as true" in withSparkSession { session =>

val data = getDfWithStringColumns(session)

val patternMatchCountry = PatternMatch("Address Line 1", """(^[0-4])""".r, where = Option("id < 5"),
analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)))
val state = patternMatchCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe
Seq(false, false, true, true, false, true, true, true)
}

"return row-level results for columns with nulls" in withSparkSession { session =>

val data = getDfWithStringColumns(session)
@@ -45,8 +70,35 @@ class PatternMatchTest extends AnyWordSpec with Matchers with SparkContextSpec w
val state = patternMatchCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")) shouldBe
data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe
Seq(true, true, true, true, false, true, true, false)
}

"return row-level results for columns with nulls filtered as true" in withSparkSession { session =>

val data = getDfWithStringColumns(session)

val patternMatchCountry = PatternMatch("Address Line 2", """\w""".r, where = Option("id < 5"),
analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)))
val state = patternMatchCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state)

data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe
Seq(true, true, true, true, false, true, true, true)
}

"return row-level results for columns with nulls filtered as null" in withSparkSession { session =>

val data = getDfWithStringColumns(session)

val patternMatchCountry = PatternMatch("Address Line 2", """\w""".r, where = Option("id < 5"),
analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
val state = patternMatchCountry.computeStateFrom(data)
val metric: DoubleMetric with FullColumn = patternMatchCountry.computeMetricFrom(state)

println(metric.fullColumn)
data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Any]("new")) shouldBe
Seq(true, true, true, true, false, null, null, null)
}
}
}
18 changes: 9 additions & 9 deletions src/test/scala/com/amazon/deequ/utils/FixtureSupport.scala
@@ -412,16 +412,16 @@ trait FixtureSupport {
import sparkSession.implicits._

Seq(
("India", "Xavier House, 2nd Floor", "St. Peter Colony, Perry Road", "Bandra (West)"),
("India", "503 Godavari", "Sir Pochkhanwala Road", "Worli"),
("India", "4/4 Seema Society", "N Dutta Road, Four Bungalows", "Andheri"),
("India", "1001D Abhishek Apartments", "Juhu Versova Road", "Andheri"),
("India", "95, Hill Road", null, null),
("India", "90 Cuffe Parade", "Taj President Hotel", "Cuffe Parade"),
("India", "4, Seven PM", "Sir Pochkhanwala Rd", "Worli"),
("India", "1453 Sahar Road", null, null)
(0, "India", "Xavier House, 2nd Floor", "St. Peter Colony, Perry Road", "Bandra (West)"),
(1, "India", "503 Godavari", "Sir Pochkhanwala Road", "Worli"),
(2, "India", "4/4 Seema Society", "N Dutta Road, Four Bungalows", "Andheri"),
(3, "India", "1001D Abhishek Apartments", "Juhu Versova Road", "Andheri"),
(4, "India", "95, Hill Road", null, null),
(5, "India", "90 Cuffe Parade", "Taj President Hotel", "Cuffe Parade"),
(6, "India", "4, Seven PM", "Sir Pochkhanwala Rd", "Worli"),
(7, "India", "1453 Sahar Road", null, null)
)
.toDF("Country", "Address Line 1", "Address Line 2", "Address Line 3")
.toDF("id", "Country", "Address Line 1", "Address Line 2", "Address Line 3")
}

def getDfWithPeriodInName(sparkSession: SparkSession): DataFrame = {
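The fixture change above threads a synthetic `id` column through `getDfWithStringColumns`, which is what lets the new PatternMatch tests filter a deterministic subset of rows, e.g. (mirroring the test usage):

```scala
// "id < 5" keeps the first five address rows; the remaining three become
// filtered rows and read as true or null depending on FilteredRowOutcome.
val analyzer = PatternMatch("Address Line 2", """\w""".r, where = Some("id < 5"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
```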
