-
Notifications
You must be signed in to change notification settings - Fork 513
/
CustomSql.scala
94 lines (85 loc) · 3.31 KB
/
CustomSql.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.analyzers
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType
import scala.util.Failure
import scala.util.Success
import scala.util.Try
/**
 * State produced by the [[CustomSql]] analyzer.
 *
 * Despite the usual `Either` convention, the numeric result lives on the `Left`
 * side and the error message on the `Right` side.
 *
 * @param stateOrError `Left(value)` when the SQL produced a value,
 *                     `Right(message)` when it could not
 */
case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] {

  /** The numeric value; throws `NoSuchElementException` if this state holds an error. */
  lazy val state: Double = stateOrError match {
    case Left(value) => value
    case Right(message) =>
      throw new NoSuchElementException(s"CustomSqlState holds an error instead of a value: $message")
  }

  /** The error message; throws `NoSuchElementException` if this state holds a value. */
  lazy val error: String = stateOrError match {
    case Right(message) => message
    case Left(_) =>
      throw new NoSuchElementException("CustomSqlState holds a value instead of an error")
  }

  /**
   * Combine two states by adding their values.
   *
   * If either side holds an error, that error state is returned unchanged
   * (this state's error wins when both failed) instead of throwing.
   *
   * @param other the state to merge with this one
   * @return the summed state, or the first error state encountered
   */
  override def sum(other: CustomSqlState): CustomSqlState = {
    (stateOrError, other.stateOrError) match {
      case (Left(mine), Left(theirs)) => CustomSqlState(Left(mine + theirs))
      case (Right(_), _) => this
      case (_, Right(_)) => other
    }
  }

  /** The numeric value of this state; throws `NoSuchElementException` if it holds an error. */
  override def metricValue(): Double = state
}
/**
 * Analyzer that runs a user-supplied SQL statement and exposes its single
 * numeric result as a dataset-level "CustomSQL" metric.
 *
 * @param expression    SQL statement; it must yield exactly one row with exactly one column
 * @param disambiguator instance name attached to the produced metric (defaults to "*")
 */
case class CustomSql(expression: String, disambiguator: String = "*") extends Analyzer[CustomSqlState, DoubleMetric] {

  /**
   * Compute the state (sufficient statistics) from the data
   *
   * The expression is run through the SQL context of `data`. A failure while
   * creating the query, a result with a column or row count other than one, or
   * a result that is null after casting to double is captured as an error state.
   * Exceptions raised while the query is being executed are not caught here.
   *
   * @param data            data frame whose SQL context is used to run the expression
   * @param filterCondition not used by this analyzer
   * @return always `Some`; the state holds either the value or an error message
   */
  override def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[CustomSqlState] = {
    Try {
      data.sqlContext.sql(expression)
    } match {
      case Failure(e) => Some(CustomSqlState(Right(e.getMessage)))
      case Success(dfSql) =>
        dfSql.columns.toSeq match {
          case Seq(resultCol) =>
            val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType))
            // Two rows are enough to tell "exactly one row" from "more than one",
            // so the full result never has to be brought back to the driver.
            val results: Seq[Row] = dfSqlCast.take(2)
            results match {
              case Seq(row) if row.isNullAt(0) =>
                // A null here would otherwise be unboxed to 0.0 and reported as a valid metric.
                Some(CustomSqlState(Right("Custom SQL returned a null or non-numeric value")))
              case Seq(row) =>
                Some(CustomSqlState(Left(row.getDouble(0))))
              case _ =>
                Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row")))
            }
          case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column")))
        }
    }
  }

  /**
   * Compute the metric from the state (sufficient statistics)
   *
   * @param state wrapper holding a state of type S (required due to typing issues...)
   * @return a successful metric for a value state, a failed metric otherwise
   */
  override def computeMetricFrom(state: Option[CustomSqlState]): DoubleMetric = {
    state match {
      // The returned state may hold either the computed value or an error message
      case Some(theState) => theState.stateOrError match {
        case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator,
          Success(value))
        case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator,
          Failure(new RuntimeException(error)))
      }
      case None =>
        DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator,
          Failure(new RuntimeException("CustomSql Failed To Run")))
    }
  }

  /**
   * Build the failed metric reported when the analyzer itself throws.
   *
   * @param failure the exception that interrupted the analyzer; kept as the cause
   *                so the root problem is not lost
   * @return a failed "CustomSQL" metric
   */
  override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = {
    DoubleMetric(Entity.Dataset, "CustomSQL", disambiguator,
      Failure(new RuntimeException("CustomSql Failed To Run", failure)))
  }
}