In [0]:
import json
import pprint
import traceback
from pyspark.sql.functions import col, lower
from concurrent.futures import ThreadPoolExecutor, as_completed

In [0]:
class DeIdentification:
    """Copy a list of tables to a target schema, masking columns according to a rule table.

    Every entry of ``p_source_table_list`` names a table in ``<source_catalog>.<source_schema>``;
    it is copied to the table of the same name in ``<target_catalog>.<target_schema>`` with the
    columns that match a masking rule replaced by the rule's SQL expression. Tables are processed
    concurrently on a thread pool (see ``process``).

    NOTE(review): the methods use a notebook-global ``spark`` session that is not defined in this
    cell (the Databricks runtime provides one) -- confirm when running elsewhere.
    """

    # Rule table used when the caller does not pass one explicitly.
    DEFAULT_RULE_TABLE = 'db_config.cfgdatamasking2'
    # Write modes understood by DataMaskingTable (compared case-insensitively).
    SUPPORTED_MODES = ('merge', 'overwrite', 'append')

    def __init__(self, p_source_catalog, p_source_schema, p_source_table_list, p_target_catalog, p_target_schema, p_mode='overwrite', p_create_if_not_exists=True):
        """Store the run-wide parameters.

        Args:
            p_source_catalog: Catalog holding the source tables.
            p_source_schema: Schema holding the source tables.
            p_source_table_list: Iterable of dicts. Each needs 'source_table' and may override
                'mode', 'create_if_not_exists' and 'keys' (merge keys: list or comma-separated string).
            p_target_catalog: Catalog to write to.
            p_target_schema: Schema to write to.
            p_mode: Default write mode ('merge', 'overwrite' or 'append').
            p_create_if_not_exists: Default for creating a missing target table.
        """
        self.params_dict = {
            'source_catalog': p_source_catalog,
            'source_schema': p_source_schema,
            'source_table_list': p_source_table_list,
            'target_catalog': p_target_catalog,
            'target_schema': p_target_schema,
            'mode': p_mode,
            'create_if_not_exists': p_create_if_not_exists
        }

        # Thread-pool size used by process(); change it with set_process_max_workers().
        self.process_max_workers = 5

    def set_process_max_workers(self, process_max_workers):
        """Set the thread-pool size used by ``process``.

        ``None`` lets ThreadPoolExecutor pick its default.

        Raises:
            ValueError: if the value is 0 or negative. Failing here is clearer than failing
                later inside ``process``.
        """
        if process_max_workers is not None and process_max_workers <= 0:
            raise ValueError(f"process_max_workers must be greater than 0 (or None), got {process_max_workers!r}")
        self.process_max_workers = process_max_workers

    def get_process_max_workers(self):
        """Return the thread-pool size used by ``process``."""
        return self.process_max_workers

    def _build_table_dict(self, table):
        """Resolve one ``source_table_list`` entry into qualified names and effective settings."""
        params = self.params_dict
        table_dict = {
            'source_table': params['source_catalog'] + '.' + params['source_schema'] + '.' + table['source_table'],
            # The target table keeps the source table's name.
            'target_table': params['target_catalog'] + '.' + params['target_schema'] + '.' + table['source_table'],
            # Per-table settings override the run-wide defaults.
            'target_mode': table.get('mode', params['mode']),
            'target_create_if_not_exists': table.get('create_if_not_exists', params['create_if_not_exists']),
            'keys': table.get('keys'),
        }
        table_dict['ProcessIdentifier'] = table_dict['source_table'] + '|' + table_dict['target_table']
        return table_dict

    def process(self):
        """De-identify every configured table concurrently and summarise the outcome.

        Returns:
            dict with the counts 'Total', 'Success' and 'ExecutionFailed', plus the
            per-table message lists 'SuccessResults' and 'ExecutionFailures'.

        Raises:
            Exception: carrying the same summary dict when at least one table failed.
        """
        v_thread_errors = []       # one message per failed table
        v_thread_results = []      # one message per successful table
        v_return_dict = {'Total': 0, 'ExecutionFailed': 0, 'Success': 0}
        print("Processing...")

        # A missing (None) table list is treated as "nothing to do" instead of a TypeError.
        table_array = [self._build_table_dict(table) for table in (self.params_dict['source_table_list'] or [])]
        v_return_dict['Total'] = len(table_array)

        if table_array:
            with ThreadPoolExecutor(max_workers=self.get_process_max_workers()) as executor:
                futures = {executor.submit(self.process_table, obj): obj for obj in table_array}
                for future in as_completed(futures):
                    obj = futures[future]
                    try:
                        future.result()
                        v_thread_results.append(str(obj['ProcessIdentifier']) + ' \nMessage=Data de-identified and copied.')
                        print("Thread Successes: ", v_thread_results)
                    except Exception as e:
                        # future.result() re-raises the worker's exception here, so format_exc() sees its traceback.
                        v_thread_errors.append(str(obj['ProcessIdentifier']) + ' \nMessage=' + str(e) + ' \nTraceback=' + traceback.format_exc().replace('\n', ''))
                        print("Thread Errors: ", v_thread_errors)

        v_return_dict['Success'] = len(v_thread_results)
        v_return_dict['ExecutionFailed'] = len(v_thread_errors)
        v_return_dict['ExecutionFailures'] = v_thread_errors
        v_return_dict['SuccessResults'] = v_thread_results

        print("\n\nSummary:\n")
        pprint.pprint(v_return_dict, indent=4)

        if v_thread_errors:
            print("Thread Errors: ", v_thread_errors)
            raise Exception(v_return_dict)

        return v_return_dict

    def process_table(self, table_dict):
        """Run DataMaskingTable for one dict produced by ``_build_table_dict`` (the default rule table is used)."""
        return self.DataMaskingTable(table_dict['source_table'], table_dict['target_table'], table_dict['target_mode'], table_dict['keys'], None, table_dict['target_create_if_not_exists'])

    @staticmethod
    def _quote_identifier(p_name):
        """Back-quote a column name for Spark SQL, doubling any embedded back-quote."""
        return '`' + str(p_name).replace('`', '``') + '`'

    @staticmethod
    def _parse_keys(p_keys):
        """Normalise merge keys to a de-duplicated list of stripped, non-empty, unquoted names.

        Accepts a comma-separated string or an iterable of names. Returns ``[]`` for empty or None input.
        """
        if not p_keys:
            return []
        raw_keys = p_keys.split(',') if isinstance(p_keys, str) else p_keys
        keys = []
        for raw_key in raw_keys:
            key = str(raw_key).strip()
            # Keys may already be back-quoted; unwrap them so they are not quoted twice.
            if len(key) >= 2 and key.startswith('`') and key.endswith('`'):
                key = key[1:-1].replace('``', '`')
            if key and key not in keys:
                keys.append(key)
        return keys

    @classmethod
    def _build_merge_sql(cls, p_target_table, p_source_select_sql, p_key_columns, p_columns):
        """Build a MERGE that upserts the masked SELECT into the target.

        Every column is updated on a key match. Unmatched rows are inserted with every column
        taken from the source.
        """
        quote = cls._quote_identifier
        on_clause = ' and '.join(f"target.{quote(c)} = source.{quote(c)}" for c in p_key_columns)
        set_clause = ', '.join(f"target.{quote(c)} = source.{quote(c)}" for c in p_columns)
        insert_columns = ', '.join(quote(c) for c in p_columns)
        insert_values = ', '.join(f"source.{quote(c)}" for c in p_columns)
        return (f"Merge into {p_target_table} as target using ({p_source_select_sql}) as source on {on_clause} "
                f"when matched then update set {set_clause} "
                f"when not matched then insert ({insert_columns}) values ({insert_values})")

    def DataMaskingTableValidation(self, p_source_table, p_target_table, p_target_mode, p_keys, p_rule_table, p_create_target_if_not_exists):
        """Check that a masking copy can run, and split the source table name into its parts.

        Returns:
            dict with:
            - 'validation_message': '' when valid, otherwise newline-terminated problem lines.
            - 'source_table_catalog', 'source_table_schema', 'source_table': parts of the
              source name, or None when absent.
            - 'target_table_exists'.
            - 'p_rule_table': the rule table that was checked; the default is used when the
              argument is None.
        """
        return_dict = {'validation_message': None, 'source_table_catalog': None, 'source_table_schema': None, 'source_table': None, 'target_table_exists': False, 'p_rule_table': None}
        validation_message = ''

        if p_rule_table is None:
            p_rule_table = self.DEFAULT_RULE_TABLE
        # Always report the rule table that was validated. Previously an explicitly passed
        # rule table came back as None, and the caller then queried "from None".
        return_dict['p_rule_table'] = p_rule_table

        if not spark.catalog.tableExists(p_rule_table):
            validation_message += f"Rule table {p_rule_table} does not exist.\n"

        if not p_target_table or (not p_create_target_if_not_exists and not spark.catalog.tableExists(p_target_table)):
            validation_message += f"Target table {p_target_table} does not exist.\n"
        else:
            return_dict['target_table_exists'] = spark.catalog.tableExists(p_target_table)

        # A missing mode is reported as unsupported rather than crashing on None.lower().
        target_mode = (p_target_mode or '').lower()
        if target_mode not in self.SUPPORTED_MODES:
            validation_message += f"Target mode {p_target_mode} is not supported. Please use 'merge', 'overwrite', or 'append'.\n"

        # Keys such as " , " are rejected too, not just an empty value.
        if target_mode == 'merge' and not self._parse_keys(p_keys):
            validation_message += f"Target mode {p_target_mode} requires keys to be specified as a list or a string (comma separate).\n"

        if spark.catalog.tableExists(p_source_table):
            name_parts = p_source_table.split('.')
            if len(name_parts) == 3:
                return_dict['source_table_catalog'], return_dict['source_table_schema'], return_dict['source_table'] = name_parts
            elif len(name_parts) == 2:
                return_dict['source_table_schema'], return_dict['source_table'] = name_parts
            # NOTE(review): other shapes (unqualified names, quoted names containing '.') leave the
            # parts as None, so the rule query cannot match a specific schema/table for them.
        else:
            validation_message += f"Source table {p_source_table} does not exist.\n"

        # Source and target must differ. Catalog table names are case-insensitive, so compare lower-cased.
        if p_source_table and p_target_table and p_source_table.lower() == p_target_table.lower():
            validation_message += f"Source table {p_source_table} and Target table {p_target_table} are the same.\n"

        return_dict['validation_message'] = validation_message

        return return_dict

    def DataMaskingTable(self, p_source_table, p_target_table, p_target_mode, p_keys=None, p_rule_table=None, p_create_target_if_not_exists=False):
        """Copy ``p_source_table`` into ``p_target_table``, masking columns per the rule table.

        The rule table (default ``db_config.cfgdatamasking2``) is queried for these columns:
        - DataMaskingSchema / DataMaskingTable: ILIKE patterns matched against the source
          schema and table names.
        - DataMaskingColumn: the source column to mask (matched case-insensitively).
        - DataMaskingString: SQL expression placed in the SELECT list instead of the column.

        When several rules match one column, the one with the lowest priority value wins. The
        order is: no '%' in either pattern, then '%' only in the schema pattern, then '%' only
        in the table pattern, then '%' in both.

        NOTE(review): DataMaskingString presumably carries its own alias (e.g.
        "sha2(ssn, 256) as ssn") so the output column keeps its name; the merge statement
        relies on that. Confirm against the rule data.

        Args:
            p_source_table: Qualified source table name.
            p_target_table: Qualified target table name.
            p_target_mode: 'merge', 'overwrite' or 'append'.
            p_keys: Merge keys (list or comma-separated string); required for 'merge'.
            p_rule_table: Rule table; None means the default.
            p_create_target_if_not_exists: Create an empty target table when it is missing.

        Returns:
            dict with 'status' ('success'), 'source_table', 'target_table' and, for merge, the
            MERGE row counters (None when not reported).

        Raises:
            Exception: with the validation message when validation fails.
        """
        return_dict = {'status': 'failed', 'message': None, 'source_table': p_source_table, 'target_table': p_target_table,
                       'num_affected_rows': None, 'num_updated_rows': None, 'num_inserted_rows': None, 'num_deleted_rows': None}

        validation_return_dict = self.DataMaskingTableValidation(p_source_table, p_target_table, p_target_mode, p_keys, p_rule_table, p_create_target_if_not_exists)

        print(validation_return_dict)

        if validation_return_dict['validation_message']:
            return_dict['message'] = validation_return_dict['validation_message']
            raise Exception(validation_return_dict['validation_message'])

        p_rule_table = validation_return_dict['p_rule_table']
        source_table_schema = validation_return_dict['source_table_schema']
        source_table = validation_return_dict['source_table']
        target_table_exists = validation_return_dict['target_table_exists']

        # Map each lower-cased column name to its SELECT-list item. Unmasked columns are
        # back-quoted so names with spaces or reserved words still parse.
        source_table_columns_list = spark.sql(f"select * from {p_source_table} limit 0").columns
        source_table_columns_dict = {column.lower(): self._quote_identifier(column) for column in source_table_columns_list}
        source_table_columns_dict_keys = list(source_table_columns_dict.keys())
        print("\nDict Keys", source_table_columns_dict_keys)
        print("\nInitial Column Dict:", source_table_columns_dict)

        # Partition by the lower-cased column so case variants of one column compete for the
        # same slot, matching the case-insensitive filter below.
        # NOTE(review): rules with equal priority for one column are picked arbitrarily.
        sql_text = f"""  select *, row_number() over (partition by lower(DataMaskingColumn) order by priority asc) as row_priority
                        from
                        (
                            select DataMaskingColumn, DataMaskingString, DataMaskingSchema, DataMaskingTable,
                            CASE
                                WHEN POSITION('%' IN DataMaskingSchema) == 0 and POSITION('%' IN DataMaskingTable) == 0 THEN 1
                                WHEN POSITION('%' IN DataMaskingSchema) > 0 and POSITION('%' IN DataMaskingTable) == 0 THEN 2
                                WHEN POSITION('%' IN DataMaskingSchema) == 0 and POSITION('%' IN DataMaskingTable) > 0 THEN 3
                                WHEN POSITION('%' IN DataMaskingSchema) > 0 and POSITION('%' IN DataMaskingTable) > 0 THEN 4
                                ELSE 5
                            END as priority
                            from {p_rule_table}
                            where '{source_table_schema}' ilike DataMaskingSchema and '{source_table}' ilike DataMaskingTable
                        )
                """

        df_rule_table = spark.sql(sql_text).filter((lower(col('DataMaskingColumn')).isin(source_table_columns_dict_keys)) & (col('row_priority') == 1))
        print(sql_text)
        df_rule_table.show(truncate=False)
        rule_dict = {row['DataMaskingColumn'].lower(): row['DataMaskingString'] for row in df_rule_table.collect()}
        print("\nRule Column Dict:", rule_dict)
        # Masking expressions replace the plain column in the SELECT list.
        source_table_columns_dict.update(rule_dict)
        print("\nFinal Column Dict:", source_table_columns_dict)
        source_table_select_str = 'select ' + ', '.join(source_table_columns_dict.values()) + ' from {0}'.format(p_source_table)
        print("\nSource Table Select String", source_table_select_str)

        if not target_table_exists and p_create_target_if_not_exists:
            # NOTE(review): the empty copy takes the *source* column types. A masking expression
            # that changes a column's type may fail or be cast on write; confirm the rules keep types.
            sql_text = f"""CREATE TABLE IF NOT EXISTS {p_target_table} AS SELECT * FROM {p_source_table} where 1=0;"""
            print(sql_text)
            spark.sql(sql_text)

        target_mode = p_target_mode.lower()
        if target_mode == 'merge':
            merge_sql_str = self._build_merge_sql(p_target_table, source_table_select_str, self._parse_keys(p_keys), list(source_table_columns_dict))
            print("\nMerge String", merge_sql_str)
            df_merge = spark.sql(merge_sql_str)
            df_merge.show()
            merge_rows = df_merge.collect()
            # Guard against an empty result (a MERGE may return no metrics row); the counters then stay None.
            merge_metrics = merge_rows[0].asDict() if merge_rows else {}
            for metric in ('num_affected_rows', 'num_updated_rows', 'num_inserted_rows', 'num_deleted_rows'):
                return_dict[metric] = merge_metrics.get(metric)
            return_dict['status'] = 'success'
        elif target_mode in ('append', 'overwrite'):
            df_target = spark.sql(source_table_select_str)
            df_target.write.format("delta").mode(p_target_mode).saveAsTable(p_target_table)
            return_dict['status'] = 'success'

        return return_dict

In [0]:
# %sql
# select 
# --SourceServerName2 as source_catalog, 
# distinct
# lower(StepName) as source_table, 
# CDCKeyColumns as keys,
# 'merge' as mode,
#  1 as create_if_not_exists
# from [CFG].[vw_TableProcessList] where SourceServerName1 = 'LEWVPXAPDBRP15'
# and SourceServerName2 = 'clm_ods_lewvpxapdbrp15_cdc_04'
# FOR JSON AUTO