In [0]:
%run ../classhandlers/UtilityHandler

In [0]:
class UnityCatalog:
    """Helpers that issue Unity Catalog DDL and metadata queries via `spark.sql`.

    Every method interpolates its arguments verbatim into SQL text, so names
    and paths must already be valid (and trusted) SQL fragments.
    Relies on a notebook-global `spark` session.
    """

    def __init__(self):
        pass

    def create_schema_fullname(self, p_full_schema_name, p_schema_path=None):
        """Create a schema if it does not already exist.

        Args:
            p_full_schema_name: Schema name as it should appear in SQL
                (e.g. ``catalog.schema``).
            p_schema_path: Optional storage path; when truthy it is emitted as a
                ``MANAGED LOCATION '<path>'`` clause.

        Returns:
            ``p_full_schema_name`` unchanged, for chaining.
        """
        v_location_path = f"MANAGED LOCATION '{p_schema_path}'" if p_schema_path else ""
        spark.sql(f"CREATE SCHEMA IF NOT EXISTS {p_full_schema_name} {v_location_path};")
        return p_full_schema_name

    def create_schema(self, p_catalog_name, p_schema_name, p_schema_path=None):
        """Create ``<p_catalog_name>.<p_schema_name>`` if missing and return that full name."""
        v_full_schema_name = f"{p_catalog_name}.{p_schema_name}"
        return self.create_schema_fullname(v_full_schema_name, p_schema_path)

    def create_table_like(self, p_table_full_name, p_table_like_full_name, p_partition_column_list=None):
        """Create an empty table with the columns of another table, if it does not exist.

        Runs ``CREATE TABLE IF NOT EXISTS ... AS SELECT * FROM <like> where 1=0``,
        so the new table gets the SELECT's columns but no rows.

        Args:
            p_table_full_name: Name of the table to create.
            p_table_like_full_name: Name of the table whose columns are copied.
            p_partition_column_list: Optional iterable of column names for a
                ``PARTITIONED BY`` clause; ``None`` or empty means no partitioning.

        Returns:
            ``p_table_full_name`` unchanged, for chaining.
        """
        sql_text = f"CREATE TABLE IF NOT EXISTS {p_table_full_name}"
        if p_partition_column_list:
            sql_text += f" PARTITIONED BY ({','.join(p_partition_column_list)})"
        # Must be an f-string: previously the literal text "{p_table_like_full_name}"
        # was sent to Spark instead of the source table name.
        sql_text += f"  AS SELECT * FROM {p_table_like_full_name} where 1=0;"
        spark.sql(sql_text)
        return p_table_full_name

    def get_default_catalog(self):
        """Return the session's current catalog name (``current_catalog()``)."""
        return spark.sql("SELECT current_catalog() as catalog_name").collect()[0][0]

    def get_table_column_names(self, p_table_full_name):
        """Return the table's column names, or ``[]`` if the table does not exist."""
        # LIMIT 0 resolves the schema without reading any rows.
        return spark.sql(f"select * from {p_table_full_name} LIMIT 0;").columns if spark.catalog.tableExists(p_table_full_name) else []

    def table_add_cluster_key(self, p_table_full_name, p_cluster_column_list=None):
        """Set clustering on a table.

        Uses ``CLUSTER BY (<columns>)`` when ``p_cluster_column_list`` is
        non-empty, otherwise ``CLUSTER BY AUTO``.

        Returns:
            ``p_table_full_name`` unchanged, for chaining.
        """
        if p_cluster_column_list:
            spark.sql(f"ALTER TABLE {p_table_full_name} CLUSTER BY ({','.join(p_cluster_column_list)});")
        else:
            spark.sql(f"ALTER TABLE {p_table_full_name} CLUSTER BY AUTO;")
        return p_table_full_name

    def table_add_primary_key(self, p_table_full_name, p_primary_key_column_list, p_rely=False):
        """Add a primary-key constraint over the given columns.

        Each key column is first altered to ``NOT NULL``, then
        ``ADD PRIMARY KEY (...)`` is issued, with ``RELY`` appended when
        ``p_rely`` is true.

        Returns:
            ``p_table_full_name`` unchanged, for chaining.
        """
        # Set primary key columns to NOT NULL before adding the constraint.
        for col in p_primary_key_column_list:
            spark.sql(f"ALTER TABLE {p_table_full_name} ALTER COLUMN {col} SET NOT NULL;")
        sql_text = f"ALTER TABLE {p_table_full_name} ADD PRIMARY KEY ({','.join(p_primary_key_column_list)})"
        sql_text += " RELY" if p_rely else ""
        spark.sql(f"{sql_text};")

        return p_table_full_name

In [0]:
class UnityCatalogTableExtension(UnityCatalog):
    """UnityCatalog helpers that generate SQL from information_schema metadata."""

    def __init__(self):
        super().__init__()

    def create_table_union_sql_text(self, p_catalog, p_schema_like, p_table_like, except_columns_list=None):
        """Build one ``union`` query over every table matching the given patterns.

        Tables are looked up in ``<p_catalog>.information_schema.tables`` with
        ``table_schema ilike p_schema_like`` and ``table_name ilike p_table_like``.
        Each match contributes ``select * [except (...)] from catalog.schema.table``
        and the pieces are joined with `` union ``.

        Args:
            p_catalog: Catalog to search.
            p_schema_like: ILIKE pattern for schema names.
            p_table_like: ILIKE pattern for table names.
            except_columns_list: Optional column names removed from every SELECT
                via ``* except (...)``; ``None`` or empty means no exclusion.

        Returns:
            The combined SQL text, or ``None`` when no table matches.
        """
        except_select_text = ''
        if except_columns_list:
            except_select_text = f" except ({', '.join(except_columns_list)})"

        # One row per matching table, each holding a single SELECT statement.
        # ORDER BY makes the generated statement identical across runs.
        sql_text = f"""select concat("select * {except_select_text} from ", table_catalog,'.',table_schema,'.',table_name) as value
                        from {p_catalog}.information_schema.tables where table_catalog = '{p_catalog}' and table_schema ilike '{p_schema_like}' and table_name ilike '{p_table_like}'
                        order by table_schema, table_name"""

        v_select_statements = [row['value'] for row in spark.sql(sql_text).collect()]
        if not v_select_statements:
            return None

        # Join with the separator rather than appending ' union' to each piece and
        # stripping it afterwards: str.rstrip(' union ') removes a *set of characters*
        # and used to eat trailing letters of table names ending in u/n/i/o.
        return ' union '.join(v_select_statements)

In [0]:
class UnityCatalogTableOperations(UnityCatalog):
    """UnityCatalog helpers that move data between existing tables."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _quote_identifier(p_column_name):
        """Return a column name as a backtick-quoted Spark SQL identifier."""
        # Names come raw from DataFrame.columns; quoting lets names with spaces,
        # hyphens, etc. through. Embedded backticks are escaped by doubling.
        return "`" + p_column_name.replace("`", "``") + "`"

    def _build_merge_sql(self, p_source_table, p_target_table, p_source_filter, p_source_column_list,
                         p_common_columns, p_keys, p_create_audit_columns_dict, p_update_audit_columns_dict):
        """Return the MERGE INTO statement for already-validated column lists."""
        q = self._quote_identifier
        v_key_set = set(p_keys)

        # Source subquery: all source columns, optionally filtered.
        v_source_select = 'select ' + ', '.join(q(c) for c in p_source_column_list) \
            + f' from {p_source_table} {p_source_filter}'

        v_on_clause = ' and '.join(f"target.{q(k)} = source.{q(k)}" for k in p_keys)

        # Matched rows: overwrite every non-key common column, plus update-audit expressions.
        v_update_assignments = [f"target.{q(c)} = source.{q(c)}" for c in p_common_columns if c not in v_key_set]
        v_update_assignments += [f"target.{q(c)} = {expr}" for c, expr in p_update_audit_columns_dict.items()]

        # Unmatched rows: insert common columns from source, plus create-audit expressions.
        v_insert_columns = [q(c) for c in p_common_columns] + [q(c) for c in p_create_audit_columns_dict]
        v_insert_values = [f"source.{q(c)}" for c in p_common_columns] + list(p_create_audit_columns_dict.values())

        v_parts = [f"Merge into {p_target_table} as target using ({v_source_select}) as source on {v_on_clause}"]
        # An empty SET list is a syntax error; skip WHEN MATCHED when there is
        # nothing to update (e.g. every common column is a key).
        if v_update_assignments:
            v_parts.append(f"when matched then update set {', '.join(v_update_assignments)}")
        v_parts.append(f"when not matched then insert ({', '.join(v_insert_columns)}) values ({', '.join(v_insert_values)})")
        return ' '.join(v_parts)

    def merge_table(self, p_source_table, p_target_table, p_source_filter=None, p_keys=None, p_create_audit_columns_dict=None, p_update_audit_columns_dict=None):
        """Upsert rows from ``p_source_table`` into ``p_target_table`` with MERGE INTO.

        Columns present in both tables (excluding audit columns) take part in
        the merge. Matched rows get every non-key common column overwritten
        from the source. Unmatched rows are inserted.

        Args:
            p_source_table: Fully qualified source table name.
            p_target_table: Fully qualified target table name.
            p_source_filter: Optional SQL predicate (without ``WHERE``) applied
                to the source.
            p_keys: Column names to join on; surrounding whitespace is stripped.
                They must exist in both tables (case-sensitive).
            p_create_audit_columns_dict: ``{target_column: sql_expression}``
                written only on insert.
            p_update_audit_columns_dict: ``{target_column: sql_expression}``
                written only on update.

        Returns:
            ``{'status': 'success', ...}``. When the MERGE returns a metrics row,
            it also includes ``num_affected_rows``, ``num_updated_rows``,
            ``num_inserted_rows`` and ``num_deleted_rows``.

        Raises:
            Exception: if either table is missing, if an audit column is absent
                from the target, or if the keys are empty or not common to both tables.
        """
        p_keys = [] if p_keys is None else [k.strip() for k in p_keys]
        p_create_audit_columns_dict = {} if p_create_audit_columns_dict is None else p_create_audit_columns_dict
        p_update_audit_columns_dict = {} if p_update_audit_columns_dict is None else p_update_audit_columns_dict
        p_source_filter = "" if p_source_filter is None else f"WHERE {p_source_filter}"

        return_dict = {'status': 'failed'}

        if not (spark.catalog.tableExists(p_source_table) and spark.catalog.tableExists(p_target_table)):
            raise Exception("Source or Target Table does not exist")

        v_source_column_list = self.get_table_column_names(p_source_table)
        v_target_column_list = self.get_table_column_names(p_target_table)
        v_source_column_set = set(v_source_column_list)
        v_target_column_set = set(v_target_column_list)
        v_audit_column_set = set(p_create_audit_columns_dict) | set(p_update_audit_columns_dict)

        # Common non-audit columns, kept in target-table order so the SQL is stable.
        v_common_columns = [c for c in v_target_column_list
                            if c in v_source_column_set and c not in v_audit_column_set]
        v_common_column_set = set(v_common_columns)

        # Audit columns must exist in the target; keys must be non-empty and common to both tables.
        if not (v_audit_column_set <= v_target_column_set
                and len(p_keys) > 0
                and all(k in v_common_column_set for k in p_keys)):
            raise Exception("Audit columns not in Target or Keys are not in either Source or Target. Keys are case-sensitive. Check Configuration.")

        merge_sql_str = self._build_merge_sql(p_source_table, p_target_table, p_source_filter,
                                              v_source_column_list, v_common_columns, p_keys,
                                              p_create_audit_columns_dict, p_update_audit_columns_dict)

        df_merge_result = spark.sql(merge_sql_str).collect()
        if len(df_merge_result) > 0:
            v_metrics = df_merge_result[0]
            return_dict['num_affected_rows'] = v_metrics['num_affected_rows']
            return_dict['num_updated_rows'] = v_metrics['num_updated_rows']
            return_dict['num_inserted_rows'] = v_metrics['num_inserted_rows']
            return_dict['num_deleted_rows'] = v_metrics['num_deleted_rows']
        return_dict['status'] = 'success'

        return return_dict
    