In [0]:
def check_duplicates(df, id):
    """Return the rows of ``df`` whose value in column ``id`` occurs more than once.

    Args:
        df: Input Spark DataFrame.
        id: Name of the key column to check for duplicates.

    Returns:
        A DataFrame holding every row whose key appears more than once, with
        an extra ``DataQuality`` column set to ``'Duplicate in Key Column'``.
    """
    key_counts = df.groupBy(id).count()
    duplicate_rows = key_counts.filter(key_counts['count'] > 1).select(id).collect()
    # collect() yields Row objects; isin() needs the plain key values, so
    # unwrap the single selected column from each Row.
    duplicate_keys = [row[0] for row in duplicate_rows]
    return df.filter(df[id].isin(duplicate_keys)) \
        .withColumn('DataQuality', lit('Duplicate in Key Column'))
                    
def check_to_nulls(df, column):
    """Return the rows of ``df`` where ``column`` is null.

    Args:
        df: Input Spark DataFrame.
        column: Name of the column to test for nulls.

    Returns:
        The matching rows with an extra ``DataQuality`` column set to
        ``'Null in Column'``.
    """
    is_missing = df[column].isNull()
    return df.where(is_missing).withColumn('DataQuality', lit('Null in Column'))

def check_for_valid_state(df, column, states):
    """Return the rows of ``df`` whose ``column`` value is not in ``states``.

    Args:
        df: Input Spark DataFrame.
        column: Name of the column holding the state value.
        states: Collection of accepted values for ``column``.

    Returns:
        The offending rows with an extra ``DataQuality`` column of the form
        ``'Invalid State of <value>'``.
    """
    filtered_df = df.filter(~df[column].isin(states))
    # Trailing space keeps the prefix from fusing with the value, matching
    # the message format of the other checks in this file.
    return filtered_df.withColumn('DataQuality', concat(lit('Invalid State of '), df[column]))

def check_for_email_format(df, colm):
    """Return the rows of ``df`` whose ``colm`` value does not look like an email.

    Args:
        df: Input Spark DataFrame.
        colm: Name of the column holding the email address.

    Returns:
        The rows failing the pattern match, with an extra ``DataQuality``
        column of the form ``'Invalid Email <value>'``.
    """
    pattern = r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$"
    rejected = df.filter(~df[colm].rlike(pattern))
    message = concat(lit('Invalid Email '), rejected[colm])
    return rejected.withColumn('DataQuality', message)

In [0]:
def check_logical_date(df, start_date, end_date):
    """Return the rows of ``df`` whose start date falls after the end date.

    Args:
        df: Input Spark DataFrame.
        start_date: Name of the column holding the start date.
        end_date: Name of the column holding the end date.

    Returns:
        The offending rows with an extra ``DataQuality`` column of the form
        ``'Invalid Date of <start> and <end>'``.
    """
    starts_after_end = to_date(df[start_date]) > to_date(df[end_date])
    message = concat(lit('Invalid Date of '), df[start_date], lit(' and '), df[end_date])
    return df.filter(starts_after_end).withColumn('DataQuality', message)

In [0]:
from cryptography.fernet import Fernet
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType



# Fernet key used by the encrypt/decrypt helpers below.
# NOTE(review): SECURITY - this key is hard-coded in source, so anyone with
# read access to the notebook can decrypt the data. It should be loaded from a
# secret store (see the commented dbutils call below) and this value rotated.
# Any data already encrypted with this key must be re-encrypted when it changes.
secret_key = 'Tg-Bs_QsJhsFCFMb9uK1J2jOhV7VEUyh2-UtrchG5Ow='
# Candidate replacement (scope/key names presumably placeholders - confirm):
# dbutils.secrets.get(scope = "key-vault-secret", key = "fernet_key_name")
# One-off generation of a fresh key:
# secret_key = Fernet.generate_key()
# Module-level cipher shared by encrypt_value / decrypt_value.
cipher = Fernet(secret_key)

@udf(returnType=StringType())
def encrypt_value(val):
    """Spark UDF that Fernet-encrypts a string value.

    Apply it to a column, e.g. ``df.withColumn("c", encrypt_value(col("c")))``.

    Args:
        val: The string value of one row; may be ``None`` or empty.

    Returns:
        The encrypted token as a string. ``None`` and empty strings are
        returned unchanged.
    """
    # The original wrapped the *result* in udf(), which passes a string where
    # a callable is expected; the function itself must be the UDF.
    if val:
        val = cipher.encrypt(val.encode()).decode()
    return val

@udf(returnType=StringType())
def decrypt_value(val):
    """Spark UDF that decrypts a Fernet token back to its original string.

    Apply it to a column, e.g. ``df.withColumn("c", decrypt_value(col("c")))``.

    Args:
        val: The encrypted token of one row; may be ``None`` or empty.

    Returns:
        The decrypted string. ``None`` and empty strings are returned
        unchanged.
    """
    # The original wrapped the *result* in udf(), which passes a string where
    # a callable is expected; the function itself must be the UDF.
    if val:
        val = cipher.decrypt(val.encode()).decode()
    return val


In [0]:
def write_to_data_lake(baseDF: DataFrame, wrtFormat: str,
                        wrtMode: str, wrtPath: str, wrtKey: str):
    """Write ``baseDF`` to ``wrtPath`` using the strategy named by ``wrtMode``.

    Supported modes (matched case-insensitively):

    * ``overwrite`` - replace the data and schema at ``wrtPath``.
    * ``merge``     - upsert into the Delta table at ``wrtPath`` on ``wrtKey``.
    * ``append``    - append the rows to ``wrtPath``.
    * ``scd2``      - close active rows whose non-key columns changed and
      insert rows that have no active match (keyed on ``EmpID``).

    Args:
        baseDF: DataFrame to write.
        wrtFormat: Spark writer format, used by ``overwrite`` and ``append``.
        wrtMode: One of the modes listed above.
        wrtPath: Target path.
        wrtKey: Key column used for the ``merge`` join condition.

    Returns:
        None.
    """
    mode = wrtMode.lower() if isinstance(wrtMode, str) else wrtMode

    if mode == "overwrite":
        baseDF.write.format(wrtFormat) \
            .mode("overwrite") \
            .option("overwriteSchema", "true") \
            .save(wrtPath)

    elif mode == "merge":
        # NOTE(review): a failed merge is only printed, so the caller cannot
        # tell the write did not happen - consider re-raising.
        try:
            oldDataTable = DeltaTable.forPath(spark, wrtPath)

            oldDataTable.alias("dsOldData").merge(
                baseDF.alias("dsNewData"),
                f"dsOldData.{wrtKey} = dsNewData.{wrtKey}",
            ) \
                .whenMatchedUpdateAll() \
                .whenNotMatchedInsertAll() \
                .execute()
        except Exception as e:
            print("Exception:" , e)

    elif mode == "append":
        baseDF.write.format(wrtFormat) \
            .mode("append") \
            .save(wrtPath)

    elif mode == "scd2":
        baseDF_scd2 = baseDF.withColumn("is_active", lit(True)) \
                     .withColumn("start_date_eff", current_timestamp()) \
                     .withColumn("end_date_eff", lit(None).cast("timestamp"))

        oldDataTable = DeltaTable.forPath(spark, wrtPath)
        # NOTE(review): the key is hard-coded to EmpID; wrtKey is not used here.
        merge_condition = "dsOldData.EmpID = dsNewData.EmpID AND dsOldData.is_active = True"

        # A row counts as changed when any non-key column differs. The
        # original prefixed this expression with " OR ", which is not valid SQL.
        # NOTE(review): "<>" yields NULL when either side is NULL, so changes
        # to or from NULL are presumably not detected - confirm.
        change_condition = " OR ".join(
            f"dsOldData.{name} <> dsNewData.{name}"
            for name in baseDF.columns
            if name != 'EmpID'
        )

        merge_builder = oldDataTable.alias("dsOldData").merge(
            baseDF_scd2.alias("dsNewData"),
            merge_condition,
        )
        # With no non-key columns nothing can have changed, so no active row
        # should be closed.
        if change_condition:
            merge_builder = merge_builder.whenMatchedUpdate(
                condition=change_condition,
                set={
                    "is_active": lit(False),
                    "end_date_eff": current_timestamp(),
                },
            )
        # NOTE(review): a changed row is only closed in this merge; its new
        # version is inserted only by a later run, when no active row matches.
        merge_builder.whenNotMatchedInsertAll().execute()

    else:
        # Previously an unrecognised mode wrote nothing without any signal.
        print("Unsupported write mode:", wrtMode)

    return


In [0]:
def update_dw_columns(baseDF : DataFrame , lsKeyColumns: List, runId: str, dataSourceName: str) -> DataFrame:
    """Add the warehouse audit columns and the integration key to ``baseDF``.

    Args:
        baseDF: Input DataFrame.
        lsKeyColumns: Columns joined with ``-`` to build ``INTEGRATION_ID``.
        runId: Value stored in the ``Run_Id`` column.
        dataSourceName: Value stored in the ``W_DATA_SOURCE`` column.

    Returns:
        ``baseDF`` with ``W_DATA_SOURCE``, ``W_INSERT_DT``, ``W_UPDATE_DT``,
        ``Run_Id`` and ``INTEGRATION_ID`` added.
    """
    baseDF = (baseDF
              .withColumn("W_DATA_SOURCE", f.lit(dataSourceName))
              .withColumn("W_INSERT_DT", f.current_timestamp())
              .withColumn("W_UPDATE_DT", f.current_timestamp())
              # Was misspelled "withColum", which raised AttributeError.
              .withColumn("Run_Id", f.lit(runId))
             )

    baseDF = baseDF.withColumn("INTEGRATION_ID", f.concat_ws('-', *lsKeyColumns))

    return baseDF

def convert_date_columns(df):
    """Convert string columns whose name ends with ``_DATE`` to date columns.

    Only the first 10 characters of each value are parsed, using the
    ``yyyy-MM-dd`` pattern. Columns that are not strings, or whose name does
    not end with ``_DATE``, are left untouched.

    Args:
        df: Input Spark DataFrame.

    Returns:
        The DataFrame with the matching columns replaced by their date form.
    """
    # Local import so the block does not rely on a file-level import.
    from pyspark.sql.types import StringType

    for field in df.schema.fields:
        # isinstance is independent of how a given PySpark version renders
        # the type as text ('StringType' vs 'StringType()').
        if field.name.endswith('_DATE') and isinstance(field.dataType, StringType):
            # substring(col, 1, 10) takes the leading 10 characters directly,
            # so no helper columns are needed and existing columns named
            # 'date_len' or 'temp_date' are no longer overwritten and dropped.
            df = df.withColumn(
                field.name,
                f.to_date(f.substring(df[field.name], 1, 10), 'yyyy-MM-dd'),
            )
    return df