In [0]:
%run ./GatewayManagementHandler

In [0]:
%run ./IngestionManagementHandler

In [0]:
import json
import pprint
import time
from datetime import datetime
import traceback
from pyspark.sql.functions import trim, col
from concurrent.futures import ThreadPoolExecutor, as_completed

In [0]:
class IngestionOrchestrator(SqlIngestionCommon):
    def __init__(self, p_environment, p_internal_product_id, p_source_server_name, p_source_database_name=None, p_internal_client_id=None, p_internal_facility_id=None):
        """Store the parameters that select which configured pipelines to run.

        Args:
            p_environment: environment name handed to ProcessConfigData.
            p_internal_product_id: product filter for the table-list configuration.
            p_source_server_name: source server filter for the configuration.
            p_source_database_name: optional source database filter.
            p_internal_client_id: optional client filter.
            p_internal_facility_id: optional facility filter.
        """
        super().__init__()

        # Selection parameters read by process() when querying the configuration.
        self.params_dict = dict(
            p_environment=p_environment,
            p_internal_product_id=p_internal_product_id,
            p_source_server_name=p_source_server_name,
            p_source_database_name=p_source_database_name,
            p_internal_client_id=p_internal_client_id,
            p_internal_facility_id=p_internal_facility_id,
        )

        # Default thread-pool size for process().
        self.process_max_workers = 10
        # Max characters kept from error messages / tracebacks in the summary.
        # NOTE: attribute name keeps its original spelling; accessors read it by this name.
        self.trackeback_length = 1000

    def set_process_max_workers(self, process_max_workers):
        """Set the number of worker threads process() uses to run pipelines."""
        self.process_max_workers = process_max_workers

    def get_process_max_workers(self):
        """Return the number of worker threads process() will use."""
        return self.process_max_workers
    
    def set_traceback_length(self, traceback_length):
        """Set how many characters of each error message/traceback process() keeps.

        Stored on the (historically misspelled) ``trackeback_length`` attribute.
        """
        self.trackeback_length = traceback_length

    def get_traceback_length(self):
        """Return the truncation length applied to error messages/tracebacks."""
        return self.trackeback_length

    def process(self):
        """Run every configured ingestion pipeline that already has an IngestionPipelineId.

        Loads the table-list configuration matching ``self.params_dict`` through
        ProcessConfigData, reports rows without an IngestionPipelineId (creating
        those pipelines is not implemented), and runs ``process_row`` for each
        remaining row on a thread pool sized by ``get_process_max_workers()``.

        Returns:
            dict: ``Total`` (rows attempted), ``Success``/``SuccessResults`` and
            ``ExecutionFailed``/``ExecutionFailures`` (counts and per-row messages).

        Raises:
            Exception: if any row raised, or if there were no existing pipelines
                to process. The summary is printed before raising.
        """
        v_return_dict = {'Total': 0, 'Success': 0, 'SuccessResults': [], 'ExecutionFailed': 0, 'ExecutionFailures': []}
        v_traceback_length = self.get_traceback_length()

        # Retrieve List of Pipelines to Run for Ingestion
        process_config = ProcessConfigData(self.params_dict["p_environment"])
        df_config = process_config.get_table_list_aggregate(
            self.params_dict["p_internal_product_id"],
            self.params_dict["p_source_server_name"],
            self.params_dict["p_source_database_name"],
            self.params_dict["p_internal_client_id"],
            self.params_dict["p_internal_facility_id"],
        )

        # Separate New (no IngestionPipelineId) and Existing Pipelines
        df_new = df_config.filter((col("IngestionPipelineId").isNull()) | (trim(col("IngestionPipelineId")) == ""))
        df_existing = df_config.filter((col("IngestionPipelineId").isNotNull()) & (trim(col("IngestionPipelineId")) != ""))

        # Materialize each side once: every count()/collect() re-executes the Spark
        # query, so repeated actions cost extra jobs and could disagree with each other.
        v_new_names = [row['IngestionPipelineName'] for row in df_new.select('IngestionPipelineName').collect()]
        v_existing_rows = df_existing.collect()

        # Process New Pipelines - Not Implemented
        if v_new_names:
            print(f"New Pipelines to be implemented: {v_new_names}")
        else:
            print("No new Pipelines to process")

        # Process Existing Pipelines
        if v_existing_rows:
            print(f"Existing Pipelines to process: {[row['IngestionPipelineName'] for row in v_existing_rows]}")
        else:
            print("No existing Pipelines to process")

        v_thread_errors = []       # Formatted messages for rows whose process_row raised
        v_thread_results = []      # Formatted messages for rows that completed
        v_row_array = []

        # Get Unity Catalog and Managed Location Root Path from Process Config Attributes
        v_unity_catalog = process_config.get_config_attribute_value('AnalyticsUnityCatalog')
        v_managed_location_root_path = process_config.get_config_attribute_value('AdlsAnalyticsFullpathUri')

        # Construct Row Array for Processing Existing Pipelines, adding common Attributes
        for row in v_existing_rows:
            v_row_dict = row.asDict()
            v_row_dict['Status'] = 'Pending'
            v_row_dict['DestinationCatalog'] = v_unity_catalog
            v_row_dict['ManagedLocationRootPath'] = v_managed_location_root_path
            # NOTE(review): this single ProcessConfigData instance is shared by all
            # worker threads — confirm it is safe for concurrent use.
            v_row_dict['process_config'] = process_config
            v_row_dict['ProcessIdentifier'] = f"Pipeline={v_row_dict['GatewayPipelineName']} Server={v_row_dict['SourceServerName1']} Database={v_row_dict['SourceDatabaseName1']}"
            v_row_array.append(v_row_dict)

        # Run Existing Pipelines in parallel; per-row failures are collected, not propagated
        with ThreadPoolExecutor(max_workers=self.get_process_max_workers()) as executor:
            futures = {executor.submit(self.process_row, obj): obj for obj in v_row_array}
            for future in as_completed(futures):
                obj = futures[future]
                try:
                    result = future.result()
                except Exception as e:
                    v_thread_errors.append(
                        str(obj['ProcessIdentifier']) + ': Message=' + str(e).replace('\n', '')[0:v_traceback_length]
                        + ' Summary:' + ' Traceback:' + traceback.format_exc().replace('\n', '')[0:v_traceback_length]
                    )
                else:
                    v_thread_results.append(str(obj['ProcessIdentifier']) + ': Message=Ingestion Successfully Completed - : Summary' + str(result))

        # Collect Statistics for Summary
        v_return_dict['Total'] = len(v_row_array)
        v_return_dict['Success'] = len(v_thread_results)
        v_return_dict['SuccessResults'] = v_thread_results
        v_return_dict['ExecutionFailed'] = len(v_thread_errors)
        v_return_dict['ExecutionFailures'] = v_thread_errors

        print("\n\nSummary:\n")
        pprint.pprint(v_return_dict, indent=4, compact=True, sort_dicts=False, width=10000)

        # Raise Exception if error or nothing to process
        if len(v_thread_errors) > 0:
            raise Exception(f"{len(v_thread_errors)} Execution errors were encountered during processing. See exception(s) for details.")
        elif len(v_row_array) == 0:
            raise Exception(f"Unable to find valid configuration(s) for Parameters: {self.params_dict}")
        else:
            print(f"Processing Complete. {len(v_thread_results)} of {len(v_row_array)} pipelines were processed successfully.\n\n")

        return v_return_dict
    
    def process_row(self, p_row_dict):
        """Run one ingestion pipeline between a gateway start and a gateway stop.

        Calls ``start_gateway_pipeline``, ``delay_pipeline_start`` and
        ``start_ingestion_pipeline`` in order, then ``stop_gateway_pipeline``.
        If any of the first three raises, ``stop_gateway_pipeline`` is still
        attempted (best effort) before the original exception is re-raised, so a
        failed run does not skip the gateway stop.

        Args:
            p_row_dict: one row built by ``process`` (config columns plus
                DestinationCatalog, ManagedLocationRootPath, process_config,
                ProcessIdentifier).

        Returns:
            dict: ``start_ingestion_pipeline``'s result merged over empty
            ``started``/``edited``/``full_refresh``/``refresh``/``status`` entries.

        Raises:
            Exception: when the state returned by ``stop_gateway_pipeline`` has
                ``IngestionState == 'FAILED'``, or whatever the start steps raise.
        """
        v_return_dict = {"started": {}, "edited": {}, "full_refresh": {}, "refresh": {}, "status": {}}

        # Get API Object
        api_obj = self.get_api_obj()

        try:
            # Start Related Gateway, delay, then run the Ingestion pipeline
            self.start_gateway_pipeline(p_row_dict, api_obj)
            self.delay_pipeline_start(p_row_dict, api_obj)
            v_return_dict.update(self.start_ingestion_pipeline(p_row_dict, api_obj))
        except Exception:
            # Best-effort cleanup; a failure here must not mask the original error.
            try:
                self.stop_gateway_pipeline(p_row_dict, api_obj)
            except Exception as stop_error:
                print(f"Warning: unable to stop gateway for Ingestion Pipeline {p_row_dict.get('IngestionPipelineName')}: {stop_error}")
            raise

        # Stop Related Gateway; the returned state is checked for IngestionState
        v_state_dict = self.stop_gateway_pipeline(p_row_dict, api_obj)

        if v_state_dict and 'IngestionState' in v_state_dict and v_state_dict['IngestionState'] == 'FAILED':
            raise Exception(f"Error: Ingestion Pipeline {p_row_dict['IngestionPipelineName']} failed - State: {v_state_dict['IngestionState']} - {v_return_dict} ")

        return v_return_dict

    def start_ingestion_pipeline(self, p_row_dict, api_obj=None):
        """Bring one ingestion pipeline in line with its configuration and run it.

        Compares the tables configured in ``p_row_dict['TableList']`` with the
        tables on the pipeline spec returned by ``api_obj.get_pipeline`` and then:

        * no added/removed tables and none flagged IsHistorical == 1: start an
          incremental update;
        * no added/removed tables and every table flagged: full refresh of the
          whole pipeline, wait for snapshots, then a follow-up update;
        * otherwise: full refresh of the flagged tables (then a follow-up update),
          incremental refresh of the unflagged existing tables, and — when tables
          were added/removed — update the pipeline spec and full refresh the
          added tables.

        Args:
            p_row_dict: row built by ``process``; reads IngestionPipelineName,
                IngestionPipelineId and the keys used by
                ``build_config_source_table_array``.
            api_obj: API client; when None a new one from ``get_api_obj()`` is used.

        Returns:
            dict: API responses keyed by action (``started``, ``full_refresh``,
            ``refresh``, ``edited``, ``added``) plus ``status`` mapping each
            action to its response's ``status``. After a full refresh the
            follow-up update's response is stored under ``started``.

        Raises:
            Exception: if ``get_pipeline`` does not return status 'ok'.
        """
        v_return_dict = {"started": {}, "edited": {}, "full_refresh": {}, "refresh": {}, "status": {}}
        if api_obj is None:
            api_obj = self.get_api_obj()

        v_pipeline_name = p_row_dict['IngestionPipelineName']
        print(f"Initialize Ingestion Pipeline {v_pipeline_name}....")

        # Set Process Status to Pending
        self.set_pipeline_process_status(p_row_dict, 'P')

        # Get Pipeline Specs for Comparison
        v_pipeline_id = p_row_dict['IngestionPipelineId']
        v_pipeline_dict = api_obj.get_pipeline(v_pipeline_id)

        if v_pipeline_dict['status'] != 'ok':
            # Previously this fell through silently and the row was reported as a success.
            raise Exception(f"Error: Unable to retrieve Ingestion Pipeline {v_pipeline_name} ({v_pipeline_id}) - {v_pipeline_dict}")

        # Compare Config Tables vs Pipeline Tables
        v_config_table_array = self.build_config_source_table_array(p_row_dict)
        v_config_table_full_refresh_array = self.build_config_source_table_array(p_row_dict, 1)
        v_pipeline_table_array = v_pipeline_dict['response']['spec']['ingestion_definition']['objects']
        v_diff_dict = self.compare_config_vs_pipeline_tables(v_config_table_array, v_pipeline_table_array, v_config_table_full_refresh_array)

        # Save Snapshot Start Time (taken before any refresh is triggered)
        v_snapshot_start_time = self.get_snapshot_startime()

        v_table_set_unchanged = v_diff_dict["added_count"] == 0 and v_diff_dict["removed_count"] == 0

        # Based on Compare Config Results - Start/Edit Pipeline, Perform Incremental or Full Refresh
        if v_table_set_unchanged and v_diff_dict["full_refresh_count"] == 0:
            # Start Pipeline Refresh
            print(f"Pipeline={v_pipeline_name} - Incremental Refresh of {v_diff_dict['config_count']} Normal Run/No Changed Tables")
            v_return_dict["started"] = api_obj.start_pipeline(v_pipeline_id)
            v_return_dict["status"]["started"] = v_return_dict["started"]["status"]
            self.wait_for_ingestion_pipeline_idle(p_row_dict, api_obj)
        elif v_table_set_unchanged and v_diff_dict["full_refresh_count"] == v_diff_dict["config_count"]:
            # Full Refresh - All Tables
            print(f"Pipeline={v_pipeline_name} - Full Refresh of ALL {v_diff_dict['config_count']} Tables")
            v_return_dict["full_refresh"] = api_obj.start_pipeline(v_pipeline_id, True)
            v_return_dict["status"]["full_refresh"] = v_return_dict["full_refresh"]["status"]
            self.wait_for_snapshots_complete(p_row_dict, v_diff_dict['config_count'], v_snapshot_start_time)
            # Follow-up run; kept separate so the full-refresh response is not overwritten
            v_return_dict["started"] = api_obj.start_pipeline(v_pipeline_id)
            v_return_dict["status"]["started"] = v_return_dict["started"]["status"]
            self.wait_for_ingestion_pipeline_idle(p_row_dict, api_obj)
        else:
            if v_diff_dict["full_refresh_count"] > 0:
                # Full Refresh List of Tables
                print(f"Pipeline={v_pipeline_name} - Full Refresh of {v_diff_dict['full_refresh_count']} Table(s)")
                v_table_name_list = self._source_table_names(v_diff_dict["full_refresh"])
                v_return_dict["full_refresh"] = api_obj.start_pipeline(v_pipeline_id, False, None, v_table_name_list)
                v_return_dict["status"]["full_refresh"] = v_return_dict["full_refresh"]["status"]
                self.wait_for_snapshots_complete(p_row_dict, v_diff_dict["full_refresh_count"], v_snapshot_start_time)
                # Follow-up run; kept separate so the full-refresh response is not overwritten
                v_return_dict["started"] = api_obj.start_pipeline(v_pipeline_id)
                v_return_dict["status"]["started"] = v_return_dict["started"]["status"]
                self.wait_for_ingestion_pipeline_idle(p_row_dict, api_obj)

            if v_diff_dict["refresh_count"] > 0:
                # Incremental Refresh List of Tables
                print(f"Pipeline={v_pipeline_name} - Incremental Refresh of {v_diff_dict['refresh_count']} Table(s)")
                v_table_name_list = self._source_table_names(v_diff_dict["refresh"])
                v_return_dict["refresh"] = api_obj.start_pipeline(v_pipeline_id, False, v_table_name_list)
                v_return_dict["status"]["refresh"] = v_return_dict["refresh"]["status"]
                self.wait_for_ingestion_pipeline_idle(p_row_dict, api_obj)

            if (v_diff_dict["added_count"] + v_diff_dict["removed_count"]) > 0:
                # Edit Pipeline - Add/Remove Tables (single quotes inside the f-string keep it valid before Python 3.12)
                print(f"Pipeline={v_pipeline_name} - Edit Pipeline - Add/Remove {v_diff_dict['added_count']}/{v_diff_dict['removed_count']} Table(s)")
                v_edit_json_dict = v_pipeline_dict['response']['spec']
                v_edit_json_dict['ingestion_definition']['objects'] = v_config_table_array
                v_return_dict["edited"] = api_obj.update_pipeline(v_pipeline_id, v_edit_json_dict)
                v_return_dict["status"]["edited"] = v_return_dict["edited"]["status"]
                self.wait_for_ingestion_pipeline_idle(p_row_dict, api_obj)
                if v_diff_dict["added_count"] > 0:
                    # Trigger Pipeline Full Refresh for Added Tables (iterate the list, not the count)
                    v_table_name_list = self._source_table_names(v_diff_dict["added"])
                    v_return_dict["added"] = api_obj.start_pipeline(v_pipeline_id, False, None, v_table_name_list)
                    v_return_dict["status"]["added"] = v_return_dict["added"]["status"]
                    self.wait_for_ingestion_pipeline_idle(p_row_dict, api_obj)

        print(f"Start Ingestion Pipeline {v_pipeline_name}....")

        return v_return_dict

    @staticmethod
    def _source_table_names(p_table_array):
        """Return the ``source_table`` of each ``{"table": {...}}`` entry, in order."""
        return [tbl['table']['source_table'] for tbl in p_table_array]

    def compare_config_vs_pipeline_tables(self, p_config_array, p_pipeline_array, p_config_full_refresh_array=None):
        """Diff the configured tables against the tables currently on the pipeline.

        Every entry is a ``{"table": {...}}`` dict with source_catalog,
        source_schema, source_table, destination_catalog, destination_schema and
        optionally destination_table. Entries are matched on
        ``src_cat.src_schema.src_table|dst_cat.dst_schema.dst_table``; when
        destination_table is absent or None the source_table name is used.

        Args:
            p_config_array: tables from configuration.
            p_pipeline_array: table objects from the pipeline spec.
            p_config_full_refresh_array: configured tables flagged for full
                refresh (None means none).

        Returns:
            dict with lists ``added`` (config-only), ``removed`` (pipeline-only),
            ``unchanged`` (in both; config entry), split into ``full_refresh`` and
            ``refresh``, plus each list's ``*_count`` and ``config_count`` /
            ``pipeline_count``.
        """
        if p_config_full_refresh_array is None:
            p_config_full_refresh_array = []

        # Later duplicates overwrite earlier ones, mirroring plain dict assignment.
        v_config_dict = {self._table_compare_key(v_config): v_config for v_config in p_config_array}
        v_pipeline_dict = {self._table_compare_key(v_pipeline): v_pipeline for v_pipeline in p_pipeline_array}
        v_full_refresh_ids = {self._table_compare_key(v_config) for v_config in p_config_full_refresh_array}

        # Tables on the pipeline: removed if no longer configured, else full/incremental refresh
        v_removed, v_unchanged, v_full_refresh, v_refresh = [], [], [], []
        for v_id, v_pipeline in v_pipeline_dict.items():
            if v_id not in v_config_dict:
                v_removed.append(v_pipeline)
                continue
            v_config = v_config_dict[v_id]
            v_unchanged.append(v_config)
            if v_id in v_full_refresh_ids:
                v_full_refresh.append(v_config)
            else:
                v_refresh.append(v_config)

        # Configured tables not yet on the pipeline need to be added
        v_added = [v_config for v_id, v_config in v_config_dict.items() if v_id not in v_pipeline_dict]

        return {
            "added": v_added,
            "removed": v_removed,
            "unchanged": v_unchanged,
            "full_refresh": v_full_refresh,
            "refresh": v_refresh,
            "added_count": len(v_added),
            "removed_count": len(v_removed),
            "unchanged_count": len(v_unchanged),
            "config_count": len(v_config_dict),
            "pipeline_count": len(v_pipeline_dict),
            "full_refresh_count": len(v_full_refresh),
            "refresh_count": len(v_refresh),
        }

    @staticmethod
    def _table_compare_key(p_entry):
        """Build the source|destination match key for one ``{"table": {...}}`` entry."""
        v_table = p_entry["table"]
        v_destination_table = v_table.get("destination_table")
        if v_destination_table is None:
            v_destination_table = v_table["source_table"]
        # str.join raises TypeError on non-str parts, like the original concatenation.
        return (
            ".".join((v_table["source_catalog"], v_table["source_schema"], v_table["source_table"]))
            + "|"
            + ".".join((v_table["destination_catalog"], v_table["destination_schema"], v_destination_table))
        )
    
    def build_config_source_table_array(self, p_row_dict, isHistorical=None):
        """Translate the row's JSON ``TableList`` into pipeline ``{"table": {...}}`` objects.

        Args:
            p_row_dict: config row; reads ``TableList`` (JSON text whose ``data``
                key holds the table entries), ``SourceDatabaseName1``,
                ``DestinationCatalog`` and ``DestinationSchema``.
            isHistorical: None keeps every table; otherwise only tables whose
                ``IsHistorical`` equals this value are kept (callers pass 1 to get
                the full-refresh set).

        Returns:
            list of table objects in ``TableList`` order.
        """
        v_tables = json.loads(p_row_dict['TableList'])["data"]
        v_objects = []
        for v_table in v_tables:
            if isHistorical is not None and v_table['IsHistorical'] != isHistorical:
                continue
            v_objects.append({
                "table": {
                    "source_catalog": p_row_dict['SourceDatabaseName1'],
                    "source_schema": v_table["SourceSchema"],
                    "source_table": v_table["SourceTable"],
                    "destination_catalog": p_row_dict['DestinationCatalog'],
                    "destination_schema": p_row_dict['DestinationSchema'],
                    "destination_table": v_table["DestinationTable"],
                }
            })
        return v_objects
                                         
    def get_api_obj(self):
        """Return a new LakeflowAPI client; a fresh instance is created on each call."""
        v_api = LakeflowAPI()
        return v_api
    
    def get_process_config_object(self):
        """Return a new ProcessConfigData bound to this orchestrator's environment."""
        v_config_cls = ProcessConfigData
        return v_config_cls(self.params_dict['p_environment'])
    
    def set_pipeline_process_status(self, p_config_row_dict, p_status):
        """Record an 'Extract' step status for the row's ingestion pipeline.

        Builds the status payload from the config row and passes it to
        ``set_pipeline_process_status`` on a ProcessConfigData instance
        obtained from ``get_process_config_object``.

        Args:
            p_config_row_dict: config row; reads IngestionPipelineName,
                InternalProductId and DataSourceId.
            p_status: status code to record (e.g. 'P', as passed by
                start_ingestion_pipeline).
        """
        v_status_payload = dict(
            PipelineName=p_config_row_dict['IngestionPipelineName'],
            InternalProductId=p_config_row_dict['InternalProductId'],
            DataSourceId=p_config_row_dict['DataSourceId'],
            StepType='Extract',
            Status=p_status,
        )
        v_process_config = self.get_process_config_object()
        v_process_config.set_pipeline_process_status(v_status_payload)