In [0]:
%run ./ProcessConfigHandler

In [0]:
%run ./LakeflowAPIHandler

In [0]:
%run ./UnityCatalogHandler

In [0]:
import json
import pprint
from concurrent.futures import ThreadPoolExecutor, as_completed

In [0]:
class CreateSqlGatewayPipeline:
    """Create gateway pipelines for the configuration rows matching the given filters.

    The constructor only stores its arguments; all lookups and API calls happen
    in ``process`` (batch entry point) or ``process_config_row`` (single row).

    ``ProcessConfigData``, ``UnityCatalog``, ``LakeflowGatewayAPI``,
    ``LakeflowAPIExtension`` and ``GlobalVars`` are not defined in this cell;
    presumably they are provided by the notebooks pulled in via ``%run`` -- verify.
    """

    def __init__(self, p_environment,  p_internal_product_id, p_source_server_name, p_source_database_name=None, p_internal_client_id=None, p_internal_facility_id=None):
        """Store the filter parameters and the default tuning values.

        The parameters are kept in ``self.params_dict`` and later passed to
        ``ProcessConfigData.get_server_list`` by ``process``.
        """
        self.params_dict = {
            'p_environment': p_environment,
            'p_internal_product_id': p_internal_product_id,
            'p_source_server_name': p_source_server_name,
            'p_source_database_name': p_source_database_name,
            'p_internal_client_id': p_internal_client_id,
            'p_internal_facility_id': p_internal_facility_id
        }

        # Maximum number of threads used by process() to create gateways in parallel.
        self.process_max_workers = 10
        # Maximum number of characters of an error message kept in the summary.
        # NOTE: the attribute name is misspelled ("trackeback"); it is kept as-is
        # because code outside this class may read the attribute directly.
        self.trackeback_length = 1000

    def set_process_max_workers(self, process_max_workers):
        """Set the thread-pool size used by ``process``."""
        self.process_max_workers = process_max_workers

    def get_process_max_workers(self):
        """Return the thread-pool size used by ``process``."""
        return self.process_max_workers

    def set_traceback_length(self, traceback_length):
        """Set how many characters of each execution error are kept in the summary."""
        self.trackeback_length = traceback_length

    def get_traceback_length(self):
        """Return how many characters of each execution error are kept in the summary."""
        return self.trackeback_length

    @staticmethod
    def _get_process_identifier(p_config_row_dict):
        """Return the row's 'ProcessIdentifier', building it from the row when absent."""
        v_process_identifier = p_config_row_dict.get('ProcessIdentifier')
        if v_process_identifier is None:
            v_process_identifier = f"Pipeline={p_config_row_dict.get('GatewayPipelineName')} Server={p_config_row_dict.get('SourceServerName1')} Database={p_config_row_dict.get('SourceDatabaseName1')}"
        return v_process_identifier

    def process(self):
        """Validate every matching configuration row and create the missing gateways.

        Rows whose 'GatewayPipelineId' is already filled in are skipped, rows that
        fail ``validate_config_row`` are reported, and the remaining rows are
        handed to ``process_config_row`` on a thread pool.

        Returns:
            dict: run statistics with the keys 'Total', 'Success', 'SuccessResults',
            'Skipped', 'SkippedResults', 'ValidationFailed', 'ValidationFailures',
            'ExecutionFailed' and 'ExecutionFailures'.

        Raises:
            Exception: when no configuration rows are found, when any validation
                or execution error occurred, or when no row was left to process.
        """
        v_thread_errors = []        # To collect all exceptions
        v_thread_results = []       # To collect successful results
        v_validation_errors = []    # To collect validation errors
        v_skipped_results = []      # To collect skipped results
        v_config_row_array = []     # To collect valid config rows to process
        # Initialize Return Dictionary - will display results in that order using pprint
        v_return_dict = {'Total': 0,'Success': 0, 'SuccessResults': [], 'Skipped': 0, 'SkippedResults': [], 'ValidationFailed' : 0, 'ValidationFailures': [], 'ExecutionFailed' : 0, 'ExecutionFailures': []}
        # Retrieve Unity Catalog from Configuration and Manage Location Root Path
        process_config = ProcessConfigData(self.params_dict['p_environment'])
        v_unity_catalog = process_config.get_config_attribute_value('AnalyticsUnityCatalog')
        v_managed_location_root_path = process_config.get_config_attribute_value('AdlsAnalyticsFullpathUri')
        # Retrieve new pipelines Configuration Data
        df_config_data_rows = process_config.get_server_list(
            self.params_dict['p_internal_product_id'],
            self.params_dict['p_source_server_name'],
            self.params_dict['p_source_database_name'],
            self.params_dict['p_internal_client_id'],
            self.params_dict['p_internal_facility_id']
        )
        # Materialize the rows once instead of running count() twice plus collect()
        v_config_rows = df_config_data_rows.collect() if df_config_data_rows is not None else []
        # Raise exception if no configurations found
        if len(v_config_rows) == 0:
            raise Exception(f"Unable to find implementation configuration for {self.params_dict}")
        v_return_dict['Total'] = len(v_config_rows)

        # Loop through new Configurations - Skip if Gateway already created...
        for row in v_config_rows:
            if row['GatewayPipelineId']:
                v_return_dict['Skipped'] += 1
                v_skipped_message = f"Warning: SKipping - PipelineID ({row['GatewayPipelineId']}) is filled in for this implementation: Pipeline={row['GatewayPipelineName']}, Server={row['SourceServerName1']}, Database={row['SourceDatabaseName1']}"
                v_skipped_results.append(v_skipped_message)
                print(v_skipped_message)
                continue

            # Get row as Dictionary and Validate Attributes
            v_config_row_dict = row.asDict()
            v_validation_message = self.validate_config_row(v_config_row_dict)
            if v_validation_message is None:
                # Validated: add the extra attributes the worker needs and queue the row
                v_config_row_dict['Status'] = 'Pending'
                v_config_row_dict['DestinationCatalog'] = v_unity_catalog
                v_config_row_dict['ManagedLocationRootPath'] = v_managed_location_root_path
                v_config_row_dict['process_config'] = process_config
                v_config_row_dict['ProcessIdentifier'] = self._get_process_identifier(v_config_row_dict)
                v_config_row_array.append(v_config_row_dict)
                print(f"Implementation will proceed for {row['GatewayPipelineName']}, Server={row['SourceServerName1']}, Database={row['SourceDatabaseName1']}")
            else:
                # Collect Validation Failure Details and save in array to display at process end.
                v_return_dict['ValidationFailed'] += 1
                v_validation_message = f"Implementation will NOT proceed for {row['GatewayPipelineName']}, Server={row['SourceServerName1']}, Database={row['SourceDatabaseName1']}: {v_validation_message}"
                v_validation_errors.append(v_validation_message)
                print(v_validation_message)

        if v_config_row_array:
            # Launch New Gateways to Create based on Configuration in parallel...
            # Collect Thread Results and Errors
            print(f"Processing {len(v_config_row_array)} rows.....")
            with ThreadPoolExecutor(max_workers=self.get_process_max_workers()) as executor:
                futures = {executor.submit(self.process_config_row, obj): obj for obj in v_config_row_array}
                for future in as_completed(futures):
                    obj = futures[future]
                    try:
                        result = future.result()
                        v_pipeline_id = result.get('pipeline_id')
                        v_thread_results.append(str(obj['ProcessIdentifier']) + ' \nMessage=Gateway Successfully Created - ID: ' + str(v_pipeline_id))
                    except Exception as e:
                        # Flatten and truncate the message so the summary stays readable
                        v_thread_errors.append(str(obj['ProcessIdentifier']) + ' \nMessage=' + str(e).replace('\n','')[0:self.get_traceback_length()])

        # Collect Run Statistics
        v_return_dict['Success'] = len(v_thread_results)
        v_return_dict['SuccessResults'] = v_thread_results
        v_return_dict['SkippedResults'] = v_skipped_results
        v_return_dict['ValidationFailures'] = v_validation_errors
        v_return_dict['ExecutionFailed'] = len(v_thread_errors)
        v_return_dict['ExecutionFailures'] = v_thread_errors
        # Display Summary
        print("\n\nSummary:\n")
        pprint.pprint(v_return_dict, indent=4, compact=True, sort_dicts=False, width=10000)
        # Raise Exception if any errors were encountered or no Configurations to Process
        if len(v_thread_errors) + len(v_validation_errors) > 0:
            raise Exception(f"{len(v_thread_errors)} Execution errors and {len(v_validation_errors)} Validation Errors were encountered during processing. See exception for details.")
        elif len(v_config_row_array) == 0:
            raise Exception(f"Unable to find valid implementation configuration(s) for {self.params_dict} - Skipped: {len(v_skipped_results)}")

        return v_return_dict

    def validate_config_row(self, p_config_row_dict):
        """Check that a configuration row carries the attributes needed to proceed.

        A key that is absent from the row is treated the same as a ``None`` value.

        Args:
            p_config_row_dict (dict): one configuration row.

        Returns:
            str | None: newline-terminated messages, one per problem found, or
            ``None`` when the row is valid.
        """
        v_messages = []
        for v_required_key in ('SourceServerName1', 'SourceDatabaseName1', 'GatewayPipelineName', 'DestinationSchema'):
            if p_config_row_dict.get(v_required_key) is None:
                v_messages.append(f'{v_required_key} is required\n')

        v_config_json = p_config_row_dict.get('ConfigJSON')
        if v_config_json is None:
            v_messages.append('ConfigJSON is required\n')
        else:
            try:
                v_json = json.loads(v_config_json)
                for v_json_key in ('initial_cluster_spec', 'connection_name'):
                    if v_json_key not in v_json:
                        v_messages.append(f'ConfigJSON is missing "{v_json_key}"\n')
            except Exception:
                # Covers unparsable text as well as parsed values that do not support
                # membership tests; unlike a bare except it lets KeyboardInterrupt through.
                v_messages.append('ConfigJSON is not valid JSON\n')

        return ''.join(v_messages) or None

    def build_pipeline_json_dict(self, p_config_row_dict):
        """Build the request payload passed to ``create_gateway_pipeline``.

        Reads 'ClusterJSONSpecsLarge' (a JSON string with a "clusters" entry),
        'GatewayPipelineName', 'DestinationCatalog', 'DestinationSchema' and
        'ConnectionName' from the row.
        NOTE(review): ``validate_config_row`` checks 'ConfigJSON' but not
        'ClusterJSONSpecsLarge' / 'ConnectionName' -- confirm these columns are
        always populated by the configuration source.

        Returns:
            dict: the pipeline definition.
        """
        # Pick the Large Cluster Initially
        v_json_clusters_dict = json.loads(p_config_row_dict['ClusterJSONSpecsLarge'])

        return {
            "name": p_config_row_dict['GatewayPipelineName'],
            "catalog": p_config_row_dict['DestinationCatalog'],
            "target": p_config_row_dict['DestinationSchema'],
            "clusters": v_json_clusters_dict["clusters"],
            "gateway_definition": {
                "connection_name": p_config_row_dict['ConnectionName'],
                "gateway_storage_catalog": p_config_row_dict['DestinationCatalog'],
                "gateway_storage_schema": p_config_row_dict['DestinationSchema'],
            }
        }

    def create_destination_schema(self, p_config_row_dict):
        """Call ``UnityCatalog.create_schema`` for the row's destination schema.

        The managed location passed along is '<ManagedLocationRootPath>/<DestinationSchema>'.

        Returns:
            Whatever ``UnityCatalog.create_schema`` returns.
        """
        v_uc_obj = UnityCatalog()
        v_schema_managed_location = f"{p_config_row_dict['ManagedLocationRootPath']}/{p_config_row_dict['DestinationSchema']}"
        return v_uc_obj.create_schema(p_config_row_dict['DestinationCatalog'], p_config_row_dict['DestinationSchema'], v_schema_managed_location)

    def process_config_row(self, p_config_row_dict):
        """Create the gateway pipeline for one configuration row.

        Steps: build the payload, create the destination schema, call
        ``create_gateway_pipeline``, store the new pipeline id through
        ``process_config_update_pipeline_id``, apply pipeline permissions when
        any are defined, then wait for the gateway to be running.

        Args:
            p_config_row_dict (dict): configuration row; it is mutated
                ('process_config' and 'GatewayPipelineId' may be set).

        Returns:
            dict: the response of ``create_gateway_pipeline``.

        Raises:
            Exception: any failure is re-raised, prefixed with the row identifier
                and chained to the original exception.
        """
        # Resolved up front so that error reporting never fails on a missing
        # 'ProcessIdentifier' when this method is called outside of process().
        v_process_identifier = self._get_process_identifier(p_config_row_dict)

        try:
            # This function can be called outside of the main process function, so we need to make sure we have the process_config object instantiated
            if 'process_config' not in p_config_row_dict:
                p_config_row_dict['process_config'] = ProcessConfigData(self.params_dict['p_environment'])

            v_api_json_data_dict = self.build_pipeline_json_dict(p_config_row_dict)
            self.create_destination_schema(p_config_row_dict)
            v_lakeflow_api = LakeflowGatewayAPI()
            reponse_dict = v_lakeflow_api.create_gateway_pipeline(v_api_json_data_dict)

            if reponse_dict['status'] != 'ok':
                raise Exception(f"Error: Create Gateway Pipeline Failed for {v_process_identifier}: {reponse_dict}")

            p_config_row_dict['GatewayPipelineId'] = reponse_dict['pipeline_id']
            self.process_config_update_pipeline_id(p_config_row_dict)

            # Set Permissions if defined
            gb_vars_obj = GlobalVars(self.params_dict['p_environment'])
            v_permissions_list = gb_vars_obj.get_pipeline_permissions_list()
            if v_permissions_list is not None and len(v_permissions_list) > 0:
                v_lakeflow_api.update_pipeline_permissions(p_config_row_dict['GatewayPipelineId'], v_permissions_list)

            # Wait up to 30 minutes for Gateway to be in Running State
            api_ext_obj = LakeflowAPIExtension()
            api_ext_obj.wait_for_gateway_running(p_config_row_dict['GatewayPipelineId'], 30)
            print("Response:", reponse_dict)
        except Exception as e:
            v_error_message = f"Error: {v_process_identifier} {e}"
            print(v_error_message)
            raise Exception(v_error_message) from e
        return reponse_dict

    def process_config_update_pipeline_id(self, p_config_row_dict):
        """Pass the row's new gateway pipeline id to ``process_config.set_pipeline_id``."""
        v_stored_procedure_params_dict = {
            'PipelineType': 'Gateway',
            'InternalProductId': p_config_row_dict['InternalProductId'],
            'SourceServerName1': p_config_row_dict['SourceServerName1'],
            'SourceDatabaseName1': p_config_row_dict['SourceDatabaseName1'],
            'SourceConfigTable': p_config_row_dict['SourceConfigTable'],
            'PipelineID': p_config_row_dict['GatewayPipelineId'],
        }

        p_config_row_dict['process_config'].set_pipeline_id(v_stored_procedure_params_dict)

    def get_gateway_api(self):
        """Return a new ``LakeflowGatewayAPI`` instance."""
        return LakeflowGatewayAPI()
          

In [0]:
class UpdateSqlGatewayPipeline:
    """Update a single attribute of an existing gateway pipeline's spec.

    Each update fetches the current pipeline definition, changes one field of
    its 'spec' and sends the whole spec back through ``update_pipeline``.
    ``LakeflowGatewayAPI`` is not defined in this cell; presumably it comes
    from the notebooks pulled in via ``%run`` -- verify.
    """

    def __init__(self, p_pipeline_id):
        """Remember the pipeline id and create the API client used for all calls."""
        self.pipeline_id = p_pipeline_id
        self.lakeflow_api = LakeflowGatewayAPI()

    def get_pipeline_json(self):
        """Return the result of ``get_pipeline`` for this pipeline id."""
        return self.lakeflow_api.get_pipeline(self.pipeline_id)

    def _update_pipeline_spec_field(self, p_field_name, p_field_value, p_field_label):
        """Fetch the pipeline spec, overwrite one field and push the spec back.

        Args:
            p_field_name (str): key inside the pipeline 'spec' to overwrite.
            p_field_value: new value for that key.
            p_field_label (str): word used in the returned 'message'.

        Returns:
            dict: starts as ``{'status': 'error', 'response': <get response>}``;
            when the get succeeded it is updated with the ``update_pipeline``
            result. A 'message' entry always describes the outcome.
        """
        return_dict = {'status': 'error'}

        v_json_dict = self.get_pipeline_json()
        # .get(): a failed lookup that carries no 'response' must still be
        # reported through the returned dict rather than raising KeyError.
        return_dict['response'] = v_json_dict.get('response')

        if v_json_dict.get('status') == 'ok':
            v_pipeline_spec = v_json_dict['response']['spec']
            v_pipeline_spec[p_field_name] = p_field_value
            v_update_response_dict = self.lakeflow_api.update_pipeline(self.pipeline_id, v_pipeline_spec)
            return_dict.update(v_update_response_dict)
            if v_update_response_dict.get('status') == 'ok':
                return_dict['message'] = f'Pipeline {p_field_label} Updated'
            else:
                return_dict['message'] = f'Pipeline {p_field_label} Updated Failed'
        else:
            return_dict['message'] = 'Pipeline Get Definition Failed: See Response Object'

        return return_dict

    def update_pipeline_clusters(self, p_cluster_list):
        """Replace the 'clusters' entry of the pipeline spec.

        Returns:
            dict: see ``_update_pipeline_spec_field``.
        """
        return self._update_pipeline_spec_field('clusters', p_cluster_list, 'Clusters')

    def update_pipeline_name(self, p_name):
        """Replace the 'name' entry of the pipeline spec.

        Returns:
            dict: see ``_update_pipeline_spec_field``.
        """
        return self._update_pipeline_spec_field('name', p_name, 'Name')

    def get_gateway_api(self):
        """Return a new ``LakeflowGatewayAPI`` instance (not the one held by this object)."""
        return LakeflowGatewayAPI()

In [0]:
# obj = UpdateSqlGatewayPipeline('8e62c559-6817-4330-bab9-109a0b89d194') 
# #('dba0fd7a-15cc-4eea-906c-fe84667220fc')
# print(obj.get_pipeline_json())
# # gw_sqlcdc_global_nprod_c2c_shared_sql_server_700084
# # # Standard_E4d_v4
# # # Standard_DS3_v2
# # v_clusters = [{"label": "default","driver_node_type_id": "Standard_E4d_v4","node_type_id": "Standard_DS3_v2","num_workers": 1}]
# # print(obj.update_pipeline_clusters(v_clusters))
# print(obj.update_pipeline_name('gw_sqlcdc_global_nprod_c2c_shared_sql_server_700084'))

In [0]:
# lakeflowsql = LakeFlowSqlGateway('dev', 27, 'nprod-c2c-shared-sql-server.database.windows.net')
# #lakeflowsql = LakeFlowSqlGateway('dev', 27, 'lewvpalyedb04.nthext.com')
# lakeflowsql.process()