In [0]:
%run ../classhandlers/ProcessConfigHandler

In [0]:
%run ../classhandlers/UnityCatalogHandler

In [0]:
import json
import pprint
import traceback
import uuid
import time
import random
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from pyspark.sql.functions import col, lower, lit, current_timestamp, current_user, current_date

In [0]:
class ClientDataSplit:
    """Split shared ODS source tables into per-client ODS tables in Unity Catalog.

    Configuration rows come from ``ProcessConfigData.get_table_split_list`` (presumably
    provided by the ``ProcessConfigHandler`` notebook run above). ``process`` validates
    them and loads each runnable row on a thread pool via ``process_config_row``, which
    retries failed write attempts after a randomized delay.
    """

    def __init__(self, p_environment, p_internal_product_id, p_source_server_name,
                 p_source_database_name=None, p_internal_client_id=None,
                 p_internal_facility_id=None, p_ingestion_pipeline_name=None):
        """Store the configuration filter parameters and default tuning values.

        Args:
            p_environment: environment name passed to ``ProcessConfigData``.
            p_internal_product_id, p_source_server_name, p_source_database_name,
            p_internal_client_id, p_internal_facility_id, p_ingestion_pipeline_name:
                filters forwarded unchanged to ``get_table_split_list``.
        """
        self.params_dict = {
            'p_environment': p_environment,
            'p_internal_product_id': p_internal_product_id,
            'p_source_server_name': p_source_server_name,
            'p_source_database_name': p_source_database_name,
            'p_internal_client_id': p_internal_client_id,
            'p_internal_facility_id': p_internal_facility_id,
            'p_ingestion_pipeline_name': p_ingestion_pipeline_name
        }

        self.process_max_workers = 50   # Thread pool size used by process()
        # Max characters kept from error messages / tracebacks
        # (misspelled attribute name kept for backward compatibility).
        self.trackeback_length = 1000
        self.max_retries = 3            # Write attempts per configuration row
        self.retry_delay = 60           # Base retry delay in seconds (passed to time.sleep)
        self.retry_delay_list = self.set_retry_delay_list()

    def set_process_max_workers(self, process_max_workers):
        """Set the maximum number of concurrent row-processing threads."""
        self.process_max_workers = process_max_workers

    def get_process_max_workers(self):
        """Return the maximum number of concurrent row-processing threads."""
        return self.process_max_workers

    def set_traceback_length(self, traceback_length):
        """Set how many characters of error text / traceback are kept per failure."""
        self.trackeback_length = traceback_length

    def get_traceback_length(self):
        """Return how many characters of error text / traceback are kept per failure."""
        return self.trackeback_length

    def set_max_retries(self, max_retries):
        """Set the number of write attempts per configuration row."""
        self.max_retries = max_retries

    def get_max_retries(self):
        """Return the number of write attempts per configuration row."""
        return self.max_retries

    def set_retry_delay(self, retry_delay):
        """Set the base retry delay (seconds) and rebuild the jittered delay choices."""
        self.retry_delay = retry_delay
        self.retry_delay_list = self.set_retry_delay_list()

    def get_retry_delay(self):
        """Return the base retry delay in seconds."""
        return self.retry_delay

    def get_random_retry_delay(self):
        """Return a delay (seconds) chosen at random from ``retry_delay_list``."""
        return random.choice(self.retry_delay_list)

    def set_retry_delay_list(self):
        """Return 25%, 50%, 100%, 125% and 150% of ``retry_delay``, truncated to int.

        Despite its name this assigns nothing; callers store the returned list.
        """
        return [int(self.retry_delay * i) for i in [.25, .50, 1, 1.25, 1.50]]

    def process(self):
        """Validate all matching configuration rows and split them in parallel.

        Rows whose ``Status`` is ``'S'`` are skipped. Valid rows are processed
        concurrently (up to ``get_process_max_workers()`` threads). The status of every
        involved pipeline is set to ``'P'`` before processing and then to ``'C'``, or to
        ``'F'`` if any of its rows failed.

        Returns:
            dict: counters ``Total``, ``Skipped``, ``ValidationFailed``,
            ``ExecutionFailed``, ``Success`` plus the lists ``SuccessResults`` (always
            empty), ``SkippedResults``, ``ValidationFailures`` and ``ExecutionFailures``.

        Raises:
            Exception: when no configuration rows are found, when any row fails validation
            or execution, or when no row is runnable (e.g. all rows were skipped).
        """
        v_thread_errors = []                  # Formatted errors of rows whose processing raised
        v_thread_results = []                 # Formatted results of rows processed successfully
        v_validation_errors = []              # Messages of rows rejected by validate_config_row
        v_skipped_results = []                # Messages of rows skipped because Status == 'S'
        v_config_row_array = []               # Validated rows to process
        v_config_unique_pipelines_dict = {}   # IngestionPipelineName -> pipeline status record
        v_pipelines_failed_list = []
        v_return_dict = {'Total': 0, 'Skipped': 0, 'ValidationFailed' : 0, 'ExecutionFailed' : 0, 'Success': 0}

        # Get Unity Catalog and Managed Location Root Path
        process_config = ProcessConfigData(self.params_dict['p_environment'])
        v_unity_catalog = process_config.get_config_attribute_value('AnalyticsUnityCatalog')
        v_managed_location_root_path = process_config.get_config_attribute_value('AdlsAnalyticsFullpathUri')

        # Load Config Data for Client Split
        df_config_data_rows = process_config.get_table_split_list(
            self.params_dict['p_internal_product_id'],
            self.params_dict['p_source_server_name'],
            self.params_dict['p_source_database_name'],
            self.params_dict['p_internal_client_id'],
            self.params_dict['p_internal_facility_id'],
            self.params_dict['p_ingestion_pipeline_name']
        )
        # Collect once and count locally instead of running separate count() Spark jobs.
        v_config_rows = df_config_data_rows.collect() if df_config_data_rows is not None else []
        if len(v_config_rows) == 0:
            raise Exception(f"Unable to find Configuration rows for {self.params_dict}")
        v_return_dict['Total'] = len(v_config_rows)

        print(f"Start - Validation of {v_return_dict['Total']} config Rows....")
        for row in v_config_rows:
            if row['Status'] == 'S':
                v_return_dict['Skipped'] += 1
                v_skipped_results.append(f"Warning: SKipping - Configuration row ({row['ProcessIdentifier']}) is marked as Success.")
                continue

            v_config_row_dict = row.asDict()
            # The catalog is needed by validate_config_row to locate the source table.
            v_config_row_dict['DestinationCatalog'] = v_unity_catalog
            v_validation_message = self.validate_config_row(v_config_row_dict)
            if v_validation_message is not None:
                v_return_dict['ValidationFailed'] += 1
                v_validation_errors.append(f"Client Data Split will NOT proceed for {row['ProcessIdentifier']}: {v_validation_message}")
                continue

            v_config_row_dict['Status'] = 'Pending'
            v_config_row_dict['ManagedLocationRootPath'] = v_managed_location_root_path
            v_config_row_dict['process_config'] = process_config
            v_config_row_array.append(v_config_row_dict)

            # Record Unique Pipeline Names for Status Facility Update
            v_config_unique_pipelines_dict[v_config_row_dict['IngestionPipelineName']] = {
                'PipelineName': v_config_row_dict['IngestionPipelineName'],
                'InternalProductId': v_config_row_dict['InternalProductId'],
                'DataSourceId': v_config_row_dict['DataSourceId'],
                'StepType': 'Ods',
                'Status': 'P'
            }
        print("End - Validation of config Rows....")

        if v_config_row_array:
            # Set Pipeline Status as Pending
            process_config.set_pipeline_process_status_list(v_config_unique_pipelines_dict.values())

            # Launch Threads Processing
            with ThreadPoolExecutor(max_workers=self.get_process_max_workers()) as executor:
                futures = {executor.submit(self.process_config_row, obj): obj for obj in v_config_row_array}
                for future in as_completed(futures):
                    obj = futures[future]
                    try:
                        result = future.result()
                        v_thread_results.append(str(obj['ProcessIdentifier']) + ' \nMessage=Client Data Split Successfully Processed. ' + str(result))
                    except Exception as e:
                        v_pipelines_failed_list.append(obj['IngestionPipelineName'])
                        v_thread_errors.append(str(obj['ProcessIdentifier']) + ' \nMessage=' + str(e).replace('\n','')[0:self.get_traceback_length()] + ' \nTraceback=' + traceback.format_exc().replace('\n','')[0:self.get_traceback_length()])

            # A pipeline is Completed only if none of its rows failed.
            v_failed_pipelines = set(v_pipelines_failed_list)
            for k, v_pipeline in v_config_unique_pipelines_dict.items():
                v_pipeline['Status'] = 'F' if k in v_failed_pipelines else 'C'

            # Set Final Pipeline Status
            process_config.set_pipeline_process_status_list(v_config_unique_pipelines_dict.values())

        v_return_dict['Success'] = len(v_thread_results)
        v_return_dict['SuccessResults'] = []  # Success details (v_thread_results) are not included
        v_return_dict['SkippedResults'] = v_skipped_results
        v_return_dict['ValidationFailures'] = v_validation_errors
        v_return_dict['ExecutionFailed'] = len(v_thread_errors)
        v_return_dict['ExecutionFailures'] = v_thread_errors

        print("\n\nSummary:\n")
        pprint.pprint(v_return_dict, indent=4, compact=True, sort_dicts=False, width=10000)

        if len(v_thread_errors) + len(v_validation_errors) > 0:
            raise Exception(f"Execution errors: {len(v_thread_errors)} , Validation Errors: {len(v_validation_errors)}  Skipped: {len(v_skipped_results)}, Success: {len(v_thread_results)} were encountered during processing. See exception for details.")
        elif len(v_config_row_array) == 0:
            raise Exception(f"Unable to find valid runnable configuration(s) for {self.params_dict} - Skipped: {len(v_skipped_results)}")

        return v_return_dict

    def process_config_row(self, p_config_row_dict):
        """Split one configuration row into its client ODS table, with retries.

        Args:
            p_config_row_dict (dict): a validated configuration row, enriched by ``process``
                with ``DestinationCatalog``, ``ManagedLocationRootPath`` and
                ``process_config``.

        Returns:
            dict: write statistics (``num_created_rows`` when the table was created, or
            the dictionary returned by ``merge_table``) plus ``duration`` (seconds),
            ``status`` and ``retry_count`` (number of failed attempts).

        Raises:
            Exception: when building the source DataFrame fails or every write attempt
            fails. In both cases the row's load status is first recorded as ``'F'``.
        """
        print("Start - Processing: ", p_config_row_dict['ProcessIdentifier'])
        v_attempt_message = ''
        v_return_dict = {}
        v_retry_count = 0
        v_start = datetime.now()
        v_process_config = p_config_row_dict['process_config']
        v_row_status_dict = {'InternalProductId': p_config_row_dict['InternalProductId'],
                             'InternalClientId': p_config_row_dict['InternalClientId'],
                             'InternalFacilityId': p_config_row_dict['InternalFacilityId'],
                             'DataSourceId': p_config_row_dict['DataSourceId'],
                             'Status': 'F',
                             'StepType': 'Ods',
                             'StepName': p_config_row_dict['DestinationTable'],
                             'PipelineName': p_config_row_dict['IngestionPipelineName']}

        try:
            df_source, v_filter_dict, v_ids_filter = self._build_source_dataframe(p_config_row_dict)
        except Exception:
            # Record the failure for this row as well, like the retry-exhaustion path below.
            v_process_config.set_data_load_process_status_detail(v_row_status_dict)
            raise

        v_retry_status = "Failed"
        while v_retry_count < self.get_max_retries():
            try:
                v_return_dict.update(self._write_client_table(p_config_row_dict, df_source, v_filter_dict, v_ids_filter))
                # Success on either path (table created, or existing table loaded).
                v_retry_status = "Success"
                break
            except Exception as e:
                v_retry_count += 1
                v_attempt_message = f"{e}"
                if v_retry_count < self.get_max_retries():
                    v_delay = self.get_random_retry_delay()
                    print(f"Attempt {v_retry_count} failed (Sleep {v_delay} seconds): {p_config_row_dict['ProcessIdentifier']} {e}")
                    time.sleep(v_delay)
                else:
                    # No attempts left: do not sleep before giving up.
                    print(f"Attempt {v_retry_count} failed (no retries left): {p_config_row_dict['ProcessIdentifier']} {e}")

        v_return_dict['duration'] = round((datetime.now() - v_start).total_seconds())
        v_return_dict['status'] = v_retry_status
        v_return_dict['retry_count'] = v_retry_count

        # On Failure - Update Status to 'F' and raise Error
        if v_retry_status != 'Success':
            v_process_config.set_data_load_process_status_detail(v_row_status_dict)
            raise Exception(f"End - All {v_retry_count} Attempts Failed - Processing ({v_return_dict['duration']}s): {p_config_row_dict['ProcessIdentifier']} - {v_attempt_message}")

        # On Success - Update Status to 'S'
        v_row_status_dict['Status'] = 'S'
        v_process_config.set_data_load_process_status_detail(v_row_status_dict)

        print(f"End - Processing ({v_return_dict['duration']}s): ", p_config_row_dict['ProcessIdentifier'])
        return v_return_dict

    def _build_source_dataframe(self, p_config_row_dict):
        """Build the source DataFrame for a row, plus its client/facility filter.

        Returns:
            tuple: ``(df_source, v_filter_dict, v_ids_filter)``. ``v_filter_dict`` holds
            the positive ``InternalClientId`` / ``InternalFacilityId`` values, and
            ``v_ids_filter`` is the same filter as an ``and``-joined SQL predicate
            (``''`` when empty).
        """
        # Default Watermark Value if not passed....
        v_default_watermark_value = '2000-01-01'
        v_watermark_value = p_config_row_dict['WatermarkValue'] if p_config_row_dict['WatermarkValue'] else v_default_watermark_value

        # Compute Source Schema and Query - IsHistorical == 0 selects the incremental template
        v_full_source_schema = p_config_row_dict['DestinationCatalog'] + '.' + p_config_row_dict['DestinationSchema']
        sql_source_query = p_config_row_dict['IncrementalExtractQuery'] if p_config_row_dict['IsHistorical'] == 0 else p_config_row_dict['HistoricalExtractQuery']
        # NOTE(review): placeholders are filled with str.format from configuration values;
        # string values are single-quoted but not escaped, so config data is assumed trusted.
        sql_source_query = sql_source_query.format(UC_SchemaName=v_full_source_schema,
                                                   TableName=p_config_row_dict['DestinationTable'],
                                                   SiteId=p_config_row_dict['SourceFacilityId'],
                                                   SourceFacilityId=p_config_row_dict['SourceFacilityId'],
                                                   WatermarkValue=f"'{v_watermark_value}'",
                                                   SourceFacilityCode=f"'{p_config_row_dict['SourceFacilityCode']}'",
                                                   SourceClientId=p_config_row_dict['SourceClientId'],
                                                   SourceClientCode=f"'{p_config_row_dict['SourceClientCode']}'",
                                                   InternalClientId=p_config_row_dict['InternalClientId'])
        df_source = spark.sql(sql_source_query)

        # Build Filter Dictionary (None is treated like 0, i.e. "no filter")
        v_filter_dict = {}
        if (p_config_row_dict['InternalClientId'] or 0) > 0:
            v_filter_dict['InternalClientId'] = p_config_row_dict['InternalClientId']
        if (p_config_row_dict['InternalFacilityId'] or 0) > 0:
            v_filter_dict['InternalFacilityId'] = p_config_row_dict['InternalFacilityId']
        v_ids_filter = ' and '.join([k + '=' + str(v) for k, v in v_filter_dict.items()])

        # Stamp the Internal Client/Facility ids on every row
        for v_col_name, v_value in v_filter_dict.items():
            df_source = df_source.withColumn(v_col_name, lit(v_value))

        # Add Audit Columns
        df_source = df_source.withColumn('DateTimeCreated', current_timestamp()) \
                             .withColumn('CreatedByUser', current_user()) \
                             .withColumn('DateTimeLastModified', current_timestamp()) \
                             .withColumn('ModifiedByUser', current_user()) \
                             .withColumn('DataSourceSchemaName', lit(p_config_row_dict['DestinationSchema']))
        return df_source, v_filter_dict, v_ids_filter

    def _write_client_table(self, p_config_row_dict, df_source, v_filter_dict, v_ids_filter):
        """Run one attempt at loading ``df_source`` into the row's client ODS table.

        Creates the client schema. If the table is missing, creates it (Delta, deletion
        vectors, auto clustering, primary key). Otherwise deletes the client/facility
        slice for historical loads, then merges on ``KeyColumns`` or, without keys,
        appends or overwrites.

        Returns:
            dict: ``{'num_created_rows': n}`` on creation, the ``merge_table`` result on
            a merge, or ``{}`` for a plain append/overwrite.
        """
        v_result_dict = {}
        # Audit column dictionaries passed to merge_table
        v_create_audit_dict = {'DateTimeCreated': 'current_timestamp', 'CreatedByUser': 'current_user'}
        v_update_audit_dict = {'DateTimeLastModified': 'current_timestamp', 'ModifiedByUser': 'current_user'}

        v_full_client_table = p_config_row_dict['DestinationCatalog'] + '.' + p_config_row_dict['ClientODSSchema'] + '.' + p_config_row_dict['ClientODSTable']
        v_key_list = self._split_key_columns(p_config_row_dict['KeyColumns'])

        # Create Client Schema if not exists
        v_location_path = p_config_row_dict['ManagedLocationRootPath'] + '/' + p_config_row_dict['ClientODSSchema']
        uc = UnityCatalogTableOperations()
        uc.create_schema(p_config_row_dict['DestinationCatalog'], p_config_row_dict['ClientODSSchema'], v_location_path)

        if not spark.catalog.tableExists(v_full_client_table):
            # Create Table with deletion vectors enabled.
            df_source.write \
                .format("delta") \
                .option("delta.enableDeletionVectors", "true") \
                .saveAsTable(v_full_client_table)

            # Add Liquid Clustering By AUTO
            uc.table_add_cluster_key(v_full_client_table)

            # Add Primary Keys if exists with RELY Option
            if v_key_list:
                uc.table_add_primary_key(v_full_client_table, v_key_list, True)

            v_result_dict['num_created_rows'] = df_source.count()
            return v_result_dict

        # Delete Rows on Client Table if is Historical
        if p_config_row_dict['IsHistorical'] == 1:
            v_delete_filter = f" WHERE {v_ids_filter}" if len(v_ids_filter) > 0 else ''
            spark.sql(f"DELETE FROM {v_full_client_table} {v_delete_filter};")

        if v_key_list:
            # Merge through a uniquely named temp view of the source data
            unique_id = uuid.uuid4().hex
            v_vw_source_data = f"tmp_{p_config_row_dict['ClientODSSchema']}_{p_config_row_dict['ClientODSTable']}_{unique_id}"
            df_source.createOrReplaceTempView(v_vw_source_data)
            try:
                if 'InternalFacilityId' in v_filter_dict and 'InternalFacilityId' not in v_key_list:
                    v_key_list.append('InternalFacilityId')
                v_result_dict.update(uc.merge_table(v_vw_source_data, v_full_client_table, v_ids_filter, v_key_list, v_create_audit_dict, v_update_audit_dict))
            finally:
                # Temp views otherwise persist in the Spark session; assumes merge_table has
                # finished using the view once it returns.
                spark.catalog.dropTempView(v_vw_source_data)
        else:
            # If No Keys, Append data if 'Watermark' specified otherwise overwrite.
            # NOTE(review): case-sensitive test on the raw incremental template (even for
            # historical loads); a template using the {WatermarkValue} placeholder does not
            # contain 'WaterMark' and therefore overwrites - confirm this is intended.
            mode = 'append' if 'WaterMark' in p_config_row_dict['IncrementalExtractQuery'] else 'overwrite'
            df_source.write.mode(mode).saveAsTable(v_full_client_table)
        return v_result_dict

    @staticmethod
    def _split_key_columns(p_key_columns):
        """Return comma-separated key column names as a list (whitespace stripped, empties dropped)."""
        if not p_key_columns:
            return []
        return [v_key.strip() for v_key in p_key_columns.split(',') if v_key.strip()]

    def validate_config_row(self, p_config_row_dict):
        """Check that a configuration row has everything needed to run.

        Returns:
            None when the row is runnable; otherwise newline-terminated messages naming
            each missing required field or, if all fields are present, the missing
            source table.
        """
        validation_message = ''
        for v_field in ('IncrementalExtractQuery', 'HistoricalExtractQuery', 'ClientODSSchema',
                        'ClientODSTable', 'DestinationSchema', 'DestinationTable'):
            if p_config_row_dict[v_field] is None:
                validation_message += f'{v_field} is required\n'
        if p_config_row_dict.get('DestinationCatalog') is None:
            validation_message += 'DestinationCatalog is required\n'

        if validation_message == "":
            v_str = p_config_row_dict['DestinationCatalog'] + '.' + p_config_row_dict['DestinationSchema'] + '.' + p_config_row_dict['DestinationTable']
            if not spark.catalog.tableExists(v_str):
                validation_message += f"Source Streaming Table {v_str} is missing\n"

        return validation_message if validation_message != '' else None
    
        


In [0]:
# obj = ClientDataSplit('dev', 27, 'lewvpalyedb04.nthext.com', None, None, None, 'ingst_sqlcdc_global_lewvpalyedb04_x3domain1')
# obj.process()