In [0]:
%pip install pyyaml -i https://artifactory.healthpartners.com/artifactory/api/pypi/python-hosted-remote/simple
import warnings
import pyspark.sql
import Databricks.SharedModules.notification_utils as notif
import yaml
from pyspark.sql.functions import col, when
from typing import Any, Optional, List
from datetime import date, datetime, timedelta
from Databricks.SharedModules.data_quality_exceptions import *
from Databricks.SharedModules.delta_table_utils import delta_table_record_count

In [0]:
class DataQualityWarning(Warning):
    """Warning category reserved for data quality findings.

    Using a dedicated category lets notification logic react to warnings
    raised by the data quality checks while ignoring unrelated categories
    such as deprecation warnings.

    Parameters
    ----------
    message:
        Text describing the data quality concern. It is kept on the
        ``message`` attribute of the instance.
    """

    def __init__(self, message):
        # Initialise the base Warning with the same text, then expose it
        # under an explicit attribute name for handlers that read it.
        super().__init__(message)
        self.message = message


In [0]:
def alert(warning, full_table_name, source_environment, target_environment):
    """Pick the Teams webhook and subtitle for a data quality alert.

    The production webhook and the 'PRODUCTION' subtitle are chosen only
    when both ``source_environment`` and ``target_environment`` equal
    'prd'; every other combination uses the non-production values.

    NOTE(review): as written, this function only selects ``subtitle`` and
    ``webhook``. Nothing is posted, ``warning`` and ``full_table_name`` are
    never read, and the function returns None. The call that actually
    delivers the message appears to be missing - confirm and restore it.

    NOTE(review): the webhook URLs are hard-coded. Incoming-webhook URLs
    act as credentials, so consider loading them from a secret store.
    """
    # Start from the non-production destination and upgrade to production
    # only when both ends of the pipeline are production.
    subtitle = 'Non-Prod'
    webhook = 'https://healthpartnersconnect.webhook.office.com/webhookb2/8e6d2cfc-f3ff-4678-ab5e-f98806ce95ef@9539230a-5213-4542-9ca6-b0ec58c41a4d/IncomingWebhook/888eb648756b446f98ec5061d8b90168/f8558c21-1856-4b0f-b1f3-e0e09dbd7eda'
    if source_environment == 'prd' and target_environment == 'prd':
        subtitle = 'PRODUCTION'
        webhook = 'https://healthpartnersconnect.webhook.office.com/webhookb2/8e6d2cfc-f3ff-4678-ab5e-f98806ce95ef@9539230a-5213-4542-9ca6-b0ec58c41a4d/IncomingWebhook/c54f59cb1b5244d7a50a0ba73798933e/f8558c21-1856-4b0f-b1f3-e0e09dbd7eda'

In [0]:
def source_dataframe_populated(extract_sdf: pyspark.sql.DataFrame,
                               full_table_name: str,
                               source_environment: str,
                               target_environment: str,
                               warning_num_obs: Optional[int] = None,
                               error_num_obs: Optional[int] = 1) -> int:
    """ Checks that the source pyspark dataframe holds enough records.

        Raises an error if the record count is less than error_num_obs,
        otherwise issues a DataQualityWarning if it is less than
        warning_num_obs.

        Parameters
        ----------
        extract_sdf: pyspark.sql.DataFrame
                The source data extract as a pyspark dataframe
        full_table_name: str
                Passed through to alert() when the error threshold is violated
        source_environment: str
                Passed through to alert() when the error threshold is violated
        target_environment: str
                Passed through to alert() when the error threshold is violated
        warning_num_obs: Optional[int], default None
                A DataQualityWarning is issued if the number of records in
                extract_sdf is less than warning_num_obs. Set to None to
                not send warnings.
        error_num_obs: Optional[int], default 1
                An error is raised if the number of records in extract_sdf
                is less than error_num_obs. Set to None to not raise errors.

        Raises
        ----------
        SourceDataframeNotPopulatedError
                If fewer than error_num_obs records in extract_sdf. alert()
                is called before the error is raised.

        Returns
        ----------
        extract_num_obs: int
                Number of records in extract_sdf. Can use if trying to avoid
                counting number of records multiple times.
    """

    extract_num_obs = extract_sdf.count()

    # error_num_obs is Optional: comparing an int against None raises a
    # TypeError, so a disabled threshold has to be skipped explicitly.
    if error_num_obs is not None and extract_num_obs < error_num_obs:
        print("***ERROR***")
        alert(f"source_dataframe_populated failure: number of source records is less than the error_threshold ({error_num_obs})", full_table_name, source_environment, target_environment)
        raise SourceDataframeNotPopulatedError(error_num_obs=error_num_obs,
                                               dataframe_num_obs=extract_num_obs)
    elif warning_num_obs is not None and (extract_num_obs < warning_num_obs):
        # Warning path only emits a Python warning; no alert() call is made.
        print("***WARNING***")
        warnings.warn(f'Only {extract_num_obs} records were detected in the source dataframe. This is less than the warning_threshold ({warning_num_obs}).',
                      DataQualityWarning)
    else:
        print("source_dataframe_populated SUCCESS")

    return extract_num_obs


In [0]:
def data_completeness(catalog: str, schema: str, table: str,
                      to_load_sdf: pyspark.sql.DataFrame,
                      spark: pyspark.sql.SparkSession,
                      source_environment: str,
                      target_environment: str,
                      warning_threshold: Optional[float] = 0.75,
                      error_threshold: Optional[float] = 0.05) -> None:
    """ Sends errors or warnings if number of records in data to be loaded is less 
        than a given percentage of the existing records in the target table. This check
        makes sense only for upsert or trunc and load strategy, not insert-only.

        Parameters
        ----------
        catalog: str
                Name of the catalog where the target table is stored (e.g. 'cleansed')
        schema: str
                Name of the target table's schema (e.g. 'care_group')
        table: str
                Name of the target table (e.g. 'emr_patient_identity')
        to_load_sdf: pyspark.sql.DataFrame
                The pyspark dataframe we're planning to load into the target table
        spark: pyspark.sql.SparkSession
                The active spark session. In python notebooks, this variable is 
                automatically created and named spark.
        source_environment: str
                Passed through to alert() when the error threshold is violated
        target_environment: str
                Passed through to alert() when the error threshold is violated
        warning_threshold: Optional, float between 0 and 1
                Send a warning if fewer than threshold * existing number of records
                found in data to be loaded. Set to None to not send warnings.
        error_threshold: Optional, float between 0 and 1
                Raise an error if fewer than threshold * existing number of records
                found in data to be loaded. Set to None to not raise errors.

        Raises
        ----------
        DataCompletenessError
                If the number of records to load is less than
                error_threshold * the number of records in the target table.
                alert() is called before the error is raised.
    """
    # The alert text and alert() both need the table's full name, which is
    # not a parameter here; build it from its parts.
    # NOTE(review): assumes callers elsewhere use 'catalog.schema.table' as
    # the full_table_name format - confirm.
    full_table_name = f"{catalog}.{schema}.{table}"

    target_table_count = delta_table_record_count(catalog=catalog, schema=schema,
                                                  table=table, spark=spark)
    load_sdf_count = to_load_sdf.count()

    if error_threshold is not None and (load_sdf_count < error_threshold * target_table_count):
        print("***ERROR***")
        alert(f"data_completeness failure: number of source records is less than {error_threshold * 100}% of records already in {full_table_name})", full_table_name, source_environment, target_environment)
        raise DataCompletenessError(load_sdf_count=load_sdf_count, target_table_count=target_table_count, table=table)
    elif warning_threshold is not None and (load_sdf_count < warning_threshold * target_table_count):
        # Warning path only emits a Python warning; no alert() call is made.
        print("***WARNING***")
        warnings.warn('Only {load_sdf_count} records to be added to {table}, but {target_table_count} records currently in table. Check for possible data quality issue on load.'.format(
            load_sdf_count=load_sdf_count,
            table=table,
            target_table_count=target_table_count
            ),
            DataQualityWarning)
    else:
        print("data_completeness SUCCESS")

In [0]:
def key_validation(sdf: pyspark.sql.DataFrame,
                   full_table_name: str,
                   source_environment: str,
                   target_environment: str,
                   key_columns: List[str]) -> None:
    """ Confirms that every row of the pyspark dataframe has a distinct
        combination of values in key_columns. Use to validate the primary
        and natural keys.

        Note: This can also be checked with great expectations and if you have a
            great expectations test suite in your project, that is the recommended
            route.

        Parameters
        ----------
        sdf: pyspark.sql.DataFrame
                The pyspark dataframe to check for duplicated keys
        full_table_name: str
                Passed through to alert() when duplicates are found
        source_environment: str
                Passed through to alert() when duplicates are found
        target_environment: str
                Passed through to alert() when duplicates are found
        key_columns: List[str]
                A list of the column names by which the data should be unique

        Raises
        ----------
        KeyViolationError
                If the total row count differs from the number of distinct
                key_columns combinations. alert() is called first.

        Returns
        ----------
        None
    """
    total_rows = sdf.count()
    distinct_keys = sdf.select(key_columns).distinct().count()

    # Equal counts mean no key combination occurs more than once.
    if total_rows == distinct_keys:
        print("key_validation SUCCESS")
        return

    alert("key_validation failure: duplicate row(s) exist across primary key columns",
          full_table_name, source_environment, target_environment)
    raise KeyViolationError(key_columns)

In [0]:
def date_range_validation(sdf: pyspark.sql.DataFrame,
                          date_columns: List[str],
                          bad_dates_allowed: int,
                          min_allowable_dt: Optional[date] = None,
                          max_allowable_dt: Optional[date] = None) -> pyspark.sql.DataFrame:
    """ Counts, per date column, the values that fall outside the range
        [min_allowable_dt, max_allowable_dt]. A column with more than
        bad_dates_allowed such values raises an error; a column with at
        least one but no more than bad_dates_allowed issues a
        DataQualityWarning. The dataframe itself is not modified.

        Parameters
        ----------
        sdf: pyspark.sql.DataFrame
                The pyspark dataframe to check for bad dates
        date_columns: List[str]
                A list of the column names to check for dates outside the range
        bad_dates_allowed: int
                Maximum number of out-of-range values tolerated in a single
                column before an error is raised
        min_allowable_dt: Optional[date], default None
                Dates less than min_allowable_dt count as bad. None means
                200 * 365 days before the time of the call.
        max_allowable_dt: Optional[date], default None
                Dates greater than max_allowable_dt count as bad. None means
                the date at the time of the call.

        Raises
        ----------
        DateRangeValidationError
                If more than bad_dates_allowed values in a single column are
                outside the allowable date range. Columns are checked in
                order, so the first offending column raises.

        Returns
        ----------
        sdf: pyspark.sql.DataFrame
                The same dataframe that was passed in, unchanged
    """
    # Resolve the default bounds on every call. Defaults written as
    # datetime.now() in the signature are evaluated only once, when the
    # function is defined, and would go stale on a long-running cluster.
    if min_allowable_dt is None:
        min_allowable_dt = (datetime.now() - timedelta(days=365*200)).date()
    if max_allowable_dt is None:
        max_allowable_dt = datetime.now().date()

    print()
    for dc in date_columns:
        # NOTE(review): the placeholder date 9999-01-01 is not special-cased
        # here, so it counts as out of range whenever max_allowable_dt is
        # earlier than it - confirm this is intended.
        bad_date_count = sdf.filter((col(dc) < min_allowable_dt) | (col(dc) > max_allowable_dt)).count()
        print(f" current column: {dc}")
        print(f" bad_date_count: {bad_date_count}")

        # Too many bad dates in this column: stop the job.
        if bad_date_count > bad_dates_allowed:
            raise DateRangeValidationError(column_name=dc, bad_date_count=bad_date_count,
                                           error_max_replacements=bad_dates_allowed)
        elif bad_date_count > 0:
            # Tolerated number of bad dates: warn and carry on.
            print("***WARNING***")
            warnings.warn('{bad_date_count} invalid dates were detected in the {dc} column. This is less than the allowed number {bad_dates_allowed} so these dates were ignored'.format(
                    bad_date_count=bad_date_count,
                    dc=dc,
                    bad_dates_allowed=bad_dates_allowed
                    ),
                    DataQualityWarning)
    print("date_range_validation SUCCESS")
    return sdf


In [0]:
# def null_value_replacement(sdf: pyspark.sql.DataFrame,
#                         non_null_columns: List[str],
#                         error_max_replacements: Optional[int] = None) -> pyspark.sql.DataFrame:
#         """ If a null value is detected in a column required to be non-null,
#                 a warning is sent and the value is replaced. Note that this data
#                 quality check is intended to serve as a backup measure for unexpected
#                 null data coming in from the source and should not be a first approach
#                 to null value replacement. Additionally, these checks are not a
#                 replacement for appropriately identifying non-null columns in your DDL
#                 statement and consequently in your table's metadata.

#                 Numeric ID 9999999999 (ten 9's), Dates and Times:  9999-01-01,
#                 Strings: 'Not Available',  Y/N:  'N'

#                 Data types are derived from column names and types.

#                 Parameters
#                 ----------
#                 sdf: pyspark.sql.DataFrame
#                         The pyspark dataframe to check for nulls
#                 non_null_columns: List[str]
#                         A list of the column names to check for null values
#                 error_max_replacements: Optional[int], default None
#                         If more than error_max_replacements occur in a single column,
#                         an error is raised.
                
#                 Raises
#                 ----------
#                 NullValueReplacementError
#                         If more than error_max_replacements records have updates to a
#                         single column OR there are any nulls in a numeric column without
#                         suffix _id then a NullValueReplacementError is raised.
                
#                 Returns
#                 ----------
#                 sdf: pyspark.sql.DataFrame
#                         The dataframe following null value replacement
#         """

#         # Define replacement values for different data types
#         id_replacement = 9999999999
#         timestamp_replacement = datetime.strptime('9999-01-01', '%Y-%m-%d')
#         date_replacement = timestamp_replacement.date()
#         string_replacement = 'Not Available'
#         yn_replacement = 'N'


#         # Replace null values with appropriate replacement values
#         for nnc in non_null_columns:
#                 null_count = sdf.filter(col(nnc).isNull()).count()

#                 if error_max_replacements is not None and null_count > error_max_replacements:
#                 raise NullValueReplacementError(column_name=nnc, null_value_count=null_count,
#                                                 threshold_error=True, error_max_replacements=error_max_replacements)
#                 elif null_count > 0:
#                 if nnc[-3:] == '_id' and sdf.schema[nnc].dataType.simpleString() in ['int', 'bigint']:
#                         sdf = sdf.withColumn(nnc, when(col(nnc).isNull(), id_replacement).otherwise(col(nnc)))
#                         warnings.warn('{null_count} values were replaced with {replacement} in the {nnc} column.'.format(
#                         null_count=null_count,
#                         replacement=id_replacement,
#                         nnc=nnc
#                         ),
#                         DataQualityWarning)
#                 elif nnc[-3:] == '_yn':
#                         sdf = sdf.withColumn(nnc, when(col(nnc).isNull(), yn_replacement).otherwise(col(nnc)))
#                         warnings.warn('{null_count} values were replaced with {replacement} in the {nnc} column.'.format(
#                         null_count=null_count,
#                         replacement=yn_replacement,
#                         nnc=nnc
#                         ),
#                         DataQualityWarning)
#                 elif sdf.schema[nnc].dataType.simpleString() == 'date':
#                         sdf = sdf.withColumn(nnc, when(col(nnc).isNull(), date_replacement).otherwise(col(nnc)))
#                         warnings.warn('{null_count} values were replaced with {replacement} in the {nnc} column.'.format(
#                         null_count=null_count,
#                         replacement=date_replacement,
#                         nnc=nnc
#                         ),
#                         DataQualityWarning)
#                 elif sdf.schema[nnc].dataType.simpleString() == 'timestamp':
#                         sdf = sdf.withColumn(nnc, when(col(nnc).isNull(), timestamp_replacement).otherwise(col(nnc)))
#                         warnings.warn('{null_count} values were replaced with {replacement} in the {nnc} column.'.format(
#                         null_count=null_count,
#                         replacement=timestamp_replacement,
#                         nnc=nnc
#                         ),
#                         DataQualityWarning)
#                 elif sdf.schema[nnc].dataType.simpleString() == 'string':
#                         sdf = sdf.withColumn(nnc, when(col(nnc).isNull(), string_replacement).otherwise(col(nnc)))
#                         warnings.warn('{null_count} values were replaced with {replacement} in the {nnc} column.'.format(
#                         null_count=null_count,
#                         replacement=string_replacement,
#                         nnc=nnc
#                         ),
#                         DataQualityWarning)
#                 else:
#                         raise NullValueReplacementError(column_name=nnc, null_value_count=null_count, threshold_error=False)

#         return sdf
        