Skip to content

Commit

Permalink
Feature/100 0 n columns (#111)
Browse files Browse the repository at this point in the history
#100: support for 0-n columns
  • Loading branch information
salamonpavel committed Jan 2, 2024
1 parent 288e9e1 commit 4057443
Show file tree
Hide file tree
Showing 12 changed files with 327 additions and 185 deletions.
1 change: 1 addition & 0 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumAgent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

package za.co.absa.atum.agent

import com.typesafe.config.{Config, ConfigFactory}
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.dispatcher.{ConsoleDispatcher, HttpDispatcher}
Expand Down
13 changes: 6 additions & 7 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import za.co.absa.atum.agent.model._
import za.co.absa.atum.model.dto._

import java.time.ZonedDateTime

import java.util.UUID
import scala.collection.immutable.ListMap

Expand All @@ -37,7 +36,7 @@ import scala.collection.immutable.ListMap
class AtumContext private[agent] (
val atumPartitions: AtumPartitions,
val agent: AtumAgent,
private var measures: Set[Measure] = Set.empty,
private var measures: Set[AtumMeasure] = Set.empty,
private var additionalData: Map[String, Option[String]] = Map.empty
) {

Expand All @@ -46,7 +45,7 @@ class AtumContext private[agent] (
*
* @return the current set of measures
*/
def currentMeasures: Set[Measure] = measures
def currentMeasures: Set[AtumMeasure] = measures

/**
* Returns the sub-partition context in the AtumContext.
Expand Down Expand Up @@ -145,7 +144,7 @@ class AtumContext private[agent] (
*
* @param newMeasure the measure to be added
*/
def addMeasure(newMeasure: Measure): AtumContext = {
def addMeasure(newMeasure: AtumMeasure): AtumContext = {
  // `measures` is a Set, so adding a measure that is already present leaves it unchanged.
  measures += newMeasure
  // Hand back this same (mutated) context.
  this
}
Expand All @@ -155,7 +154,7 @@ class AtumContext private[agent] (
*
* @param newMeasures the set of measures to be added
*/
def addMeasures(newMeasures: Set[Measure]): AtumContext = {
def addMeasures(newMeasures: Set[AtumMeasure]): AtumContext = {
  // Union of the current measures with the supplied ones; duplicates collapse because both are Sets.
  measures ++= newMeasures
  // Hand back this same (mutated) context.
  this
}
Expand All @@ -165,15 +164,15 @@ class AtumContext private[agent] (
*
* @param measureToRemove the measure to be removed
*/
def removeMeasure(measureToRemove: Measure): AtumContext = {
def removeMeasure(measureToRemove: AtumMeasure): AtumContext = {
  // Removing a measure that is not in the set is a no-op.
  measures -= measureToRemove
  // Hand back this same (mutated) context.
  this
}

private[agent] def copy(
atumPartitions: AtumPartitions = this.atumPartitions,
agent: AtumAgent = this.agent,
measures: Set[Measure] = this.measures,
measures: Set[AtumMeasure] = this.measures,
additionalData: Map[String, Option[String]] = this.additionalData
): AtumContext = {
new AtumContext(atumPartitions, agent, measures, additionalData)
Expand Down
213 changes: 84 additions & 129 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,197 +17,152 @@
package za.co.absa.atum.agent.model

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.atum.agent.core.MeasurementProcessor
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

/**
* This trait represents a measure that can be applied to a column.
* Type of different measures to be applied to the columns.
*/
sealed trait Measure extends MeasurementProcessor with MeasureType {
val measuredColumn: String
sealed trait Measure {
/** Name identifying the kind of measure (e.g. "count", "distinctCount"); it is the name put into the measure DTO. */
val measureName: String
/** Columns the measure is computed over; may be empty for measures that need no column (e.g. the record count). */
def controlColumns: Seq[String]
}

/**
* This trait represents a measure type that can be applied to a column.
*/
trait MeasureType {
val measureName: String
trait AtumMeasure extends Measure with MeasurementProcessor {
/** Type of the value this measure produces; it is passed along with the stringified result of the measurement. */
val resultValueType: ResultValueType.ResultValueType
}

/**
* This object contains all the possible measures that can be applied to a column.
*/
object Measure {

private val valueColumnName: String = "value"
object AtumMeasure {

val supportedMeasures: Seq[MeasureType] = Seq(
RecordCount,
DistinctRecordCount,
SumOfValuesOfColumn,
AbsSumOfValuesOfColumn,
SumOfHashesOfColumn
val supportedMeasureNames: Seq[String] = Seq(
RecordCount.measureName,
DistinctRecordCount.measureName,
SumOfValuesOfColumn.measureName,
AbsSumOfValuesOfColumn.measureName,
SumOfHashesOfColumn.measureName
)
val supportedMeasureNames: Seq[String] = supportedMeasures.map(_.measureName)

case class RecordCount private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class RecordCount private (measureName: String) extends AtumMeasure {
private val columnExpression = count("*")

override def function: MeasurementFunction =
  (ds: DataFrame) => {
    // The count aggregation yields exactly one row holding one column.
    // Read the cell itself: calling toString on the Row would render the whole row
    // as "[<count>]" rather than the bare number expected for a Long-typed result.
    val rows = ds.select(columnExpression).collect()
    ResultOfMeasurement(rows(0)(0).toString, resultValueType)
  }
}
object RecordCount extends MeasureType {
def apply(measuredColumn: String): RecordCount = RecordCount(measuredColumn, measureName, resultValueType)

override val measureName: String = "count"
override def controlColumns: Seq[String] = Seq.empty
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
}
object RecordCount {
  /** Name under which the record-count measure is listed among the supported measures. */
  private[agent] val measureName: String = "count"

  /** Creates the record-count measure; it takes no control columns. */
  def apply(): RecordCount = new RecordCount(measureName)
}

case class DistinctRecordCount private (measureName: String, controlCols: Seq[String]) extends AtumMeasure {
require(controlCols.nonEmpty, "At least one control column has to be defined.")

case class DistinctRecordCount private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
private val columnExpression = countDistinct(col(controlCols.head), controlCols.tail.map(col): _*)

override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(measuredColumn)).distinct().count().toString
ResultOfMeasurement(resultValue, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0)(0).toString, resultValueType)
}
}
object DistinctRecordCount extends MeasureType {
def apply(measuredColumn: String): DistinctRecordCount = {
DistinctRecordCount(measuredColumn, measureName, resultValueType)
}

override val measureName: String = "distinctCount"
override def controlColumns: Seq[String] = controlCols
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
}
object DistinctRecordCount {
  /** Name under which the distinct-record-count measure is listed among the supported measures. */
  private[agent] val measureName: String = "distinctCount"

  /**
   * Creates the distinct-count measure over the given control columns.
   *
   * @param controlCols columns whose distinct value combinations are counted; the measure's
   *                    constructor requires at least one column
   */
  def apply(controlCols: Seq[String]): DistinctRecordCount =
    new DistinctRecordCount(measureName, controlCols)
}

case class SumOfValuesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class SumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(column)

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(col(valueColumnName))
val resultValue = aggregateColumn(ds, measuredColumn, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
}
}
object SumOfValuesOfColumn extends MeasureType {
def apply(measuredColumn: String): SumOfValuesOfColumn = {
SumOfValuesOfColumn(measuredColumn, measureName, resultValueType)
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override val measureName: String = "aggregatedTotal"
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
}
object SumOfValuesOfColumn {
  /** Name under which the sum-of-values measure is listed among the supported measures. */
  private[agent] val measureName: String = "aggregatedTotal"

  /**
   * Creates the measure summing the values of a single column.
   *
   * @param controlCol name of the column to sum
   */
  def apply(controlCol: String): SumOfValuesOfColumn =
    new SumOfValuesOfColumn(measureName, controlCol)
}

case class AbsSumOfValuesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class AbsSumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(abs(column))

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(abs(col(valueColumnName)))
val resultValue = aggregateColumn(ds, measuredColumn, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
}
}
object AbsSumOfValuesOfColumn extends MeasureType {
def apply(measuredColumn: String): AbsSumOfValuesOfColumn = {
AbsSumOfValuesOfColumn(measuredColumn, measureName, resultValueType)
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override val measureName: String = "absAggregatedTotal"
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Double
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
}
object AbsSumOfValuesOfColumn {
  /** Name under which the sum-of-absolute-values measure is listed among the supported measures. */
  private[agent] val measureName: String = "absAggregatedTotal"

  /**
   * Creates the measure summing the absolute values of a single column.
   *
   * @param controlCol name of the column whose absolute values are summed
   */
  def apply(controlCol: String): AbsSumOfValuesOfColumn =
    new AbsSumOfValuesOfColumn(measureName, controlCol)
}

case class SumOfHashesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

case class SumOfHashesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnExpression: Column = sum(crc32(col(controlCol).cast("String")))
override def function: MeasurementFunction = (ds: DataFrame) => {

val aggregatedColumnName = ds.schema.getClosestUniqueName("sum_of_hashes")
val value = ds
.withColumn(aggregatedColumnName, crc32(col(measuredColumn).cast("String")))
.agg(sum(col(aggregatedColumnName)))
.collect()(0)(0)
val resultValue = if (value == null) "" else value.toString
ResultOfMeasurement(resultValue, ResultValueType.String)
}
}
object SumOfHashesOfColumn extends MeasureType {
def apply(measuredColumn: String): SumOfHashesOfColumn = {
SumOfHashesOfColumn(measuredColumn, measureName, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(Option(resultValue(0)(0)).getOrElse("").toString, resultValueType)
}

override val measureName: String = "hashCrc32"
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.String
}
object SumOfHashesOfColumn {
  /** Name under which the CRC32 hash-sum measure is listed among the supported measures. */
  private[agent] val measureName: String = "hashCrc32"

  /**
   * Creates the measure summing the CRC32 hashes of a single column's values.
   *
   * @param controlCol name of the column whose values are hashed and summed
   */
  def apply(controlCol: String): SumOfHashesOfColumn =
    new SumOfHashesOfColumn(measureName, controlCol)
}

/**
* This method aggregates a column of a given data frame using a given aggregation expression.
* The result is converted to a string.
*
* @param df A data frame
* @param measureColumn A column to aggregate
* @param aggExpression An aggregation expression
* @return A string representation of the aggregated value
*/
private def aggregateColumn(
df: DataFrame,
measureColumn: String,
aggExpression: Column
): String = {
val dataType = df.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
private def castForAggregation(
dataType: DataType,
column: Column
): Column = {
dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
// scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
// res11: Any = -9223372036854775808
// Converting to BigDecimal fixes the issue
// val ds2 = ds.select(col(measurement.measuredColumn).cast(DecimalType(38, 0)).as("value"))
// ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = df.select(
col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
column.cast(DecimalType(38, 0))
case _: StringType =>
// Support for string type aggregation
val ds2 = df.select(
col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
column.cast(DecimalType(38, 18))
case _ =>
column
}
}

private def handleAggregationResult(dataType: DataType, result: Any): String = {
val aggregatedValue = dataType match {
case _: LongType =>
if (result == null) 0 else result
case _: StringType =>
val value =
if (collected == null) new java.math.BigDecimal(0)
else collected.asInstanceOf[java.math.BigDecimal]
if (result == null) new java.math.BigDecimal(0)
else result.asInstanceOf[java.math.BigDecimal]
value.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5), but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
val ds2 = df.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
if (result == null) 0 else result
}
// check if total is required to be presented as larger type - big decimal

workaroundBigDecimalIssues(aggregatedValue)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object Measurement {
object MeasurementProvided {

private def handleSpecificType[T](
measure: Measure,
measure: AtumMeasure,
resultValue: T,
requiredType: ResultValueType.ResultValueType
): MeasurementProvided[T] = {
Expand All @@ -65,7 +65,7 @@ object Measurement {
* @tparam T A type of the result value.
* @return A measurement.
*/
def apply[T](measure: Measure, resultValue: T): Measurement = {
def apply[T](measure: AtumMeasure, resultValue: T): Measurement = {
resultValue match {
case l: Long =>
handleSpecificType[Long](measure, l, ResultValueType.Long)
Expand All @@ -86,11 +86,11 @@ object Measurement {
}
}

/**
* When the Atum Agent itself performs the measurements, using Spark, then in some cases some adjustments are
* needed - thus we are converting the results to strings always - but we need to keep the information about
* the actual type as well.
*/
case class MeasurementByAtum(measure: Measure, resultValue: String, resultType: ResultValueType.ResultValueType)
extends Measurement
/**
* When the Atum Agent itself performs the measurements, using Spark, then in some cases some adjustments are
* needed - thus we are converting the results to strings always - but we need to keep the information about
* the actual type as well.
*/
case class MeasurementByAtum(measure: AtumMeasure, resultValue: String, resultType: ResultValueType.ResultValueType)
extends Measurement
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ private [agent] object MeasurementBuilder {

private [agent] def buildMeasurementDTO(measurement: Measurement): MeasurementDTO = {
val measureName = measurement.measure.measureName
val measuredColumns = Seq(measurement.measure.measuredColumn)
val measureDTO = MeasureDTO(measureName, measuredColumns)

val measureDTO = MeasureDTO(measureName, measurement.measure.controlColumns)
val measureResultDTO = MeasureResultDTO(TypedValue(measurement.resultValue.toString, measurement.resultType))

MeasurementDTO(measureDTO, measureResultDTO)
Expand Down
Loading

0 comments on commit 4057443

Please sign in to comment.