In [0]:
import pprint
import traceback
from pyspark.sql.functions import collect_list, concat_ws, expr, col, length, lit, concat, lower, when, round, trim
from pyspark.sql.types import StructType,StructField, StringType, LongType
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

In [0]:
%run ./UtilityHandler

In [0]:
%run ./GlobalVariables

In [0]:
%run ./ProcessConfigHandler

In [0]:
class LakeflowCountReconciliation():
    """Run row-count reconciliation for every configured Lakeflow ingestion pipeline.

    Configuration rows are read through ``ProcessConfigData.get_table_list_aggregate``; every row
    with a non-blank ``IngestionPipelineId`` becomes one ``CountReconciliation`` task. Tasks run on
    a thread pool and each non-empty result is appended to ``saveas_table``.
    """

    def __init__(self, p_environment,  p_internal_product_id, p_source_server_name, p_source_database_name=None, p_internal_client_id=None, p_internal_facility_id=None,p_ingestion_pipeline_name=None):
        """Store the configuration filter parameters (no I/O happens here).

        Args:
            p_environment: environment name forwarded to ``ProcessConfigData`` and
                ``CountReconciliation``.
            p_internal_product_id, p_source_server_name, p_source_database_name,
            p_internal_client_id, p_internal_facility_id, p_ingestion_pipeline_name:
                filters forwarded to ``ProcessConfigData.get_table_list_aggregate``.
        """
        self.params_dict = {
            'p_environment': p_environment,
            'p_internal_product_id': p_internal_product_id,
            'p_source_server_name': p_source_server_name,
            'p_source_database_name': p_source_database_name,
            'p_internal_client_id': p_internal_client_id,
            'p_internal_facility_id': p_internal_facility_id,
            'p_ingestion_pipeline_name': p_ingestion_pipeline_name
        }

        # Number of reconciliation tasks process() runs concurrently.
        self.process_max_workers = 10
        # Maximum characters kept from each error message / traceback in the summary.
        self.traceback_length = 1000
        # Table that reconciliation rows are appended to.
        self.saveas_table = 'db_config.ProcessReconciliationCounts'

    @property
    def trackeback_length(self):
        """Backward-compatible alias for the previously misspelled attribute name."""
        return self.traceback_length

    @trackeback_length.setter
    def trackeback_length(self, value):
        self.traceback_length = value

    def set_process_max_workers(self, process_max_workers):
        self.process_max_workers = process_max_workers

    def get_process_max_workers(self):
        return self.process_max_workers

    def set_traceback_length(self, traceback_length):
        self.traceback_length = traceback_length

    def get_traceback_length(self):
        return self.traceback_length

    def set_saveas_table(self, saveas_table):
        self.saveas_table = saveas_table

    def get_saveas_table(self):
        return self.saveas_table

    def process(self):
        """Reconcile every configured pipeline in parallel and print a summary.

        Returns:
            dict with keys ``Total``, ``Success``, ``SuccessResults``, ``ExecutionFailed`` and
            ``ExecutionFailures``.

        Raises:
            Exception: if any task failed, or if no configuration row matched the parameters.
        """
        v_return_dict = {'Total': 0, 'Success': 0, 'SuccessResults': [], 'ExecutionFailed' : 0, 'ExecutionFailures': []}
        v_thread_errors = []       # To collect all exceptions
        v_thread_results = []      # To collect successful results
        v_row_array = []
        v_max_len = self.get_traceback_length()

        # Retrieve List of Source and Destination Systems
        process_config = ProcessConfigData(self.params_dict["p_environment"])
        df_config = process_config.get_table_list_aggregate (
                                                                self.params_dict["p_internal_product_id"],
                                                                self.params_dict["p_source_server_name"],
                                                                self.params_dict["p_source_database_name"],
                                                                self.params_dict["p_internal_client_id"],
                                                                self.params_dict["p_internal_facility_id"],
                                                                self.params_dict['p_ingestion_pipeline_name']
                                                            )

        # Only rows that already have an ingestion pipeline can be reconciled.
        df_existing = df_config.filter((col("IngestionPipelineId").isNotNull()) & (trim(col("IngestionPipelineId")) != ""))

        # Get Unity Catalog and Managed Location Root Path from Process Config Attributes
        v_unity_catalog = process_config.get_config_attribute_value('AnalyticsUnityCatalog')
        v_managed_location_root_path = process_config.get_config_attribute_value('AdlsAnalyticsFullpathUri')

        # Construct Row Array for Processing Existing Pipelines and adding common Attributes
        for row in df_existing.collect():
            v_row_dict = row.asDict()
            v_row_dict['Status'] = 'Pending'
            v_row_dict['SourceType'] = 'SqlServer'
            v_row_dict['DestinationType'] = 'UnityCatalog'
            v_row_dict['DestinationCatalog'] = v_unity_catalog
            v_row_dict['ManagedLocationRootPath'] = v_managed_location_root_path
            v_row_dict['process_config'] = process_config
            v_row_dict['ProcessIdentifier'] = f"Source={v_row_dict['SourceServerName1']}/{v_row_dict['SourceDatabaseName1']} <-> Destination={v_row_dict['DestinationCatalog']}/{v_row_dict['DestinationSchema']}"
            v_row_dict['Params'] = self.params_dict
            v_row_array.append(v_row_dict)

        # Launch Threads for Process Existing Pipelines in parallel
        with ThreadPoolExecutor(max_workers=self.get_process_max_workers()) as executor:
            futures = {executor.submit(self.process_row, obj): obj for obj in v_row_array}
            for future in as_completed(futures):
                obj = futures[future]
                try:
                    result = future.result()
                    v_thread_results.append(str(obj['ProcessIdentifier']) + ' \nMessage=Count Reconciliation Successfully retrieved - : ' + str(result))
                except Exception as e:
                    # future.result() re-raises the worker's exception, so format_exc() shows its traceback.
                    v_thread_errors.append(str(obj['ProcessIdentifier']) + ' \nMessage=' + str(e).replace('\n','')[0:v_max_len] + ' \nTraceback=' + traceback.format_exc().replace('\n','')[0:v_max_len] )

        # Collect Statistics for Summary
        v_return_dict['Total'] = len(v_row_array)
        v_return_dict['Success'] = len(v_thread_results)
        v_return_dict['SuccessResults'] = v_thread_results
        v_return_dict['ExecutionFailed'] = len(v_thread_errors)
        v_return_dict['ExecutionFailures'] = v_thread_errors

        print("\n\nSummary:\n")
        pprint.pprint(v_return_dict, indent=4, compact=True, sort_dicts=False, width=10000)
        # Raise Exception if error or nothing to process
        if len(v_thread_errors) > 0:
            raise Exception(f"{len(v_thread_errors)} Execution errors were encountered during processing. See exception(s) for details.")
        elif len(v_row_array) == 0:
            raise Exception(f"Unable to find valid configuration(s) for Parameters: {self.params_dict}")
        else:
            print(f"Processing Complete. {len(v_thread_results)} of {len(v_row_array)} Reconciliation Tasks were processed successfully.\n\n")

        return v_return_dict

    def process_row(self, p_row_dict):
        """Reconcile one configuration row and append any result rows to ``saveas_table``.

        Args:
            p_row_dict: row dict built by ``process`` (must contain ``Params``, ``SourceType``,
                ``SourceServerName1``, ``SourceDatabaseName1``, ``DataSourceShortName``,
                ``DestinationType``, ``DestinationCatalog`` and ``DestinationSchema``).

        Returns:
            Number of reconciliation rows produced (0 when there was no result).
        """
        obj_recon = CountReconciliation(p_row_dict['Params']['p_environment'],
                                        p_row_dict['SourceType'],
                                        p_row_dict['SourceServerName1'],
                                        p_row_dict['SourceDatabaseName1'],
                                        p_row_dict['DataSourceShortName'],
                                        p_row_dict['DestinationType'],
                                        p_row_dict['DestinationCatalog'],
                                        p_row_dict['DestinationSchema'])
        df_result = obj_recon.process()

        if df_result is None:
            return 0

        # count() re-executes the whole plan (JDBC query + a count(*) per destination table),
        # so evaluate it only once.
        v_row_count = df_result.count()
        if v_row_count > 0:
            df_result.write.mode('append').saveAsTable(self.get_saveas_table())

        return v_row_count


In [0]:
class CountReconciliation():
    """Compare per-table row counts between a SQL Server database and a Unity Catalog schema.

    Source counts come from SQL Server catalog views for CDC-tracked tables; destination counts
    come from ``count(*)`` over every table in the destination schema. Both sides are joined on a
    ``<source short name>_ods_<table name>`` key column named ``tbl``.
    """
    def __init__(self, p_environment, p_source_type, p_source_server_name, p_source_database_name, p_source_short_name, p_destination_type,  p_destination_catalog, p_destination_schema):
        """Store the reconciliation parameters and resolve the secret scope via ``GlobalVars``."""
        self.params_dict = {
            'p_environment': p_environment,
            'p_source_type': p_source_type,
            'p_source_server_name': p_source_server_name,
            'p_source_database_name': p_source_database_name,
            'p_source_short_name': p_source_short_name,
            'p_destination_type': p_destination_type,
            'p_destination_catalog': p_destination_catalog,
            'p_destination_schema': p_destination_schema
        }
        gv_obj = GlobalVars(p_environment)
        # Databricks secret scope used to read the SQL Server password.
        self.dbx_scope = gv_obj.get_database_config('dbx_scope')

    @staticmethod
    def _quote_identifier(p_name):
        """Return ``p_name`` as a back-quoted Spark SQL identifier (embedded back-quotes doubled)."""
        return '`' + str(p_name).replace('`', '``') + '`'

    @staticmethod
    def _quote_literal(p_value):
        """Return ``p_value`` as a single-quoted Spark SQL string literal (``\\`` and ``'`` escaped)."""
        return "'" + str(p_value).replace('\\', '\\\\').replace("'", "\\'") + "'"

    @classmethod
    def _build_destination_count_sql(cls, p_destination_catalog, p_destination_schema, p_table_names):
        """Build one statement yielding (DestinationTable, DestinationCount) per table, ordered by name.

        ``p_table_names`` must be non-empty; each name becomes one ``select`` joined with
        ``union all`` (the literal table name makes every branch distinct, so no de-duplication
        is needed).
        """
        v_schema_prefix = f"{cls._quote_identifier(p_destination_catalog)}.{cls._quote_identifier(p_destination_schema)}"
        v_selects = [
            f"select {cls._quote_literal(v_table)} as DestinationTable, count(*) as DestinationCount "
            f"from {v_schema_prefix}.{cls._quote_identifier(v_table)}"
            for v_table in p_table_names
        ]
        return ' union all '.join(v_selects) + ' order by 1 '

    def get_sqlsvr_counts(self, p_source_type, p_source_server_name, p_source_database_name, p_source_short_name):
        """Read row counts of the CDC-tracked tables of a SQL Server database.

        Returns:
            DataFrame with SourceSchema, SourceTable, SourceCount, SourceType,
            SourceServerName1, SourceDatabaseName1 and the join key ``tbl``.
        """
        jdbcUsername = "cdc_analytics"
        jdbcPassword = dbutils.secrets.get(scope = self.dbx_scope, key = "EtlSqlCDCSecret")
        # Short host names are qualified with the corporate domain.
        jdbcHostname = f"{p_source_server_name}.NTHRIVE.NTHCRP.COM" if '.com' not in p_source_server_name.lower() else p_source_server_name
        jdbcPort = 1433
        jdbcDatabase = p_source_database_name

        # sys.partitions has one row per partition of every index. Only the heap (index_id 0) or
        # clustered index (index_id 1) rows describe the table's data, and a partitioned table has
        # several of them, so sum those rather than taking DISTINCT row counts (which produced
        # duplicate / partial rows for partitioned tables and filtered indexes).
        jdbcQuery = """select s.name as SourceSchema, t.name as SourceTable, sum(p.rows) as SourceCount
                        from sys.partitions p
                        join sys.tables t on t.object_id = p.object_id
                        join sys.schemas s on t.schema_id = s.schema_id
                        where t.is_tracked_by_cdc = 1
                          and p.index_id in (0, 1)
                        group by s.name, t.name
                        """

        df_remote_query = (
            spark.read
            .format("sqlserver")
            .option("host", jdbcHostname)
            .option("port", jdbcPort)
            .option("user", jdbcUsername)
            .option("password", jdbcPassword)
            .option("database", jdbcDatabase)
            .option("trustServerCertificate", "true")
            .option("hostNameInCertificate", "*.NTHRIVE.NTHCRP.COM")
            .option("query", jdbcQuery)
            .load()
            .withColumn("SourceType", lit(p_source_type))
            .withColumn("SourceServerName1", lit(p_source_server_name))
            .withColumn("SourceDatabaseName1", lit(p_source_database_name))
            .withColumn('tbl', concat(lit(p_source_short_name), lit('_ods_'), lower(col('SourceTable'))))
            .orderBy('tbl')
        )

        return df_remote_query

    def get_destination_streaming_counts(self, p_destination_type, p_destination_catalog, p_destination_schema, p_source_short_name):
        """Count rows of every table in the destination Unity Catalog schema.

        Returns:
            DataFrame with DestinationTable, DestinationCount, DestinationType,
            DestinationCatalog, DestinationSchema and the join key ``tbl``. When the schema has no
            tables, an empty DataFrame with those columns is returned.
        """
        # List destination tables; in LIKE '__%' ESCAPE '_' the first '_' escapes the second, so
        # names starting with '_' are excluded.
        sql_text = f"""
            select table_name
            from {self._quote_identifier(p_destination_catalog)}.information_schema.tables
            where table_schema = {self._quote_literal(p_destination_schema)}
              and table_catalog = {self._quote_literal(p_destination_catalog)}
              and table_name not like '__%' escape '_'
            order by 1
        """
        v_table_names = [v_row['table_name'] for v_row in spark.sql(sql_text).collect()]

        if not v_table_names:
            # Previously an empty schema produced the unparseable statement ' order by 1 '.
            # An empty frame keeps the full outer join working (source rows get null counts).
            v_empty_schema = StructType([
                StructField('DestinationTable', StringType(), True),
                StructField('DestinationCount', LongType(), True),
                StructField('DestinationType', StringType(), True),
                StructField('DestinationCatalog', StringType(), True),
                StructField('DestinationSchema', StringType(), True),
                StructField('tbl', StringType(), True),
            ])
            return spark.createDataFrame([], v_empty_schema)

        sql_result_text = self._build_destination_count_sql(p_destination_catalog, p_destination_schema, v_table_names)

        df_result = spark.sql(sql_result_text) \
                    .withColumn("DestinationType", lit(p_destination_type)) \
                    .withColumn("DestinationCatalog", lit(p_destination_catalog)) \
                    .withColumn("DestinationSchema", lit(p_destination_schema)) \
                    .withColumn('tbl', concat(lit(p_source_short_name), lit('_ods_'), col('DestinationTable')))

        return df_result

    def get_recon_source_stream(self, p_source_type, p_source_server_name, p_source_database_name, p_source_short_name, p_destination_type, p_destination_catalog, p_destination_schema):
        """Full-outer-join source and destination counts on ``tbl`` and add difference columns.

        ``DiffPercentage`` is null when ``SourceCount`` is 0 or null (avoids divide-by-zero under
        ANSI mode). The result is also displayed, which triggers its own evaluation.
        """
        # Get Counts on Source and Stream
        df_remote = self.get_sqlsvr_counts(p_source_type, p_source_server_name, p_source_database_name, p_source_short_name)
        df_stream = self.get_destination_streaming_counts(p_destination_type, p_destination_catalog, p_destination_schema, p_source_short_name)
        # Take one timestamp so date/hour/minute are mutually consistent.
        v_now = datetime.now()
        report_date = v_now.strftime("%Y-%m-%d")
        report_hour = v_now.strftime("%H")
        report_minute = v_now.strftime("%M")
        # Join Source and Stream
        df_recon = df_remote.join(df_stream, ['tbl'], 'fullouter') \
                    .withColumn('DiffCount', round(col('SourceCount') - col('DestinationCount'), 0)) \
                    .withColumn('DiffPercentage', when(col('SourceCount') != 0, round(col('DiffCount') / col('SourceCount') * 100, 2))) \
                    .withColumn('ReportDate', lit(report_date)) \
                    .withColumn('ReportHour', lit(report_hour)) \
                    .withColumn('ReportMinute', lit(report_minute))

        df_recon.display()
        return df_recon

    def process(self):
        """Run the reconciliation with the parameters given to the constructor."""
        return self.get_recon_source_stream(self.params_dict['p_source_type'],
                                            self.params_dict['p_source_server_name'],
                                            self.params_dict['p_source_database_name'],
                                            self.params_dict['p_source_short_name'],
                                            self.params_dict['p_destination_type'],
                                            self.params_dict['p_destination_catalog'],
                                            self.params_dict['p_destination_schema'])
        
       

In [0]:
# obj = LakeflowCountReconciliation('dev', 27,'lewvpalyedb04.nthext.com', None, None, None,'ingst_sqlcdc_global_lewvpalyedb04_100001')
# obj.process()