## Utility Functions

In [1]:
# Turn on the Delta optimize-write setting for the current Spark session
spark.conf.set("spark.databricks.delta.optimizeWrite.enabled", "true")

In [None]:
import logging
# Standard library
import json
import math
import re
import time
from datetime import datetime, date, timedelta
from typing import Optional, List, Dict, Any, Tuple, Set

# PySpark SQL
from pyspark.sql import Row
from pyspark.sql.dataframe import DataFrame
import pyspark.sql.functions as F
from pyspark.sql.functions import (
    col,
    current_timestamp,
    hash as spark_hash,
    lit,
    regexp_replace,
    to_date,
)
from pyspark.sql.types import (
    DateType,
    DoubleType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from pyspark.sql.utils import AnalysisException

# Delta Lake
from delta.tables import DeltaTable
from delta.exceptions import ConcurrentAppendException

In [5]:
# Announce which version of this utility notebook was loaded
print("Data processing utils - version 3.14")

# Logging
# Root logging at INFO; this notebook's logger is named after the current job id
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(str(mssparkutils.env.getJobId()))

#############################################################################################################

def extract_storage_account_name(url):
    '''
    Return the storage account name embedded in a storage endpoint URL.

    The text before the first '.' is taken, and anything up to and including the
    last '//' (the URL scheme) is stripped, e.g.
    "https://myacct.dfs.core.windows.net/" -> "myacct".
    '''
    before_first_dot = url.partition('.')[0]
    _, _, account_name = before_first_dot.rpartition('//')
    return account_name


def mount_workspace_point(abfss_path, linked_service_name, mount_point_name):
    '''
    Ensure a workspace-scoped mount of abfss_path exists at mount_point_name.

    The current mounts are inspected. When nothing is mounted at mount_point_name, or
    the existing mount has a different source or linked service, the mount point is
    unmounted (best effort) and mounted again with workspace scope using the linked
    service credentials. Otherwise the existing mount is left untouched.

    Params:
        abfss_path - abfss:// URI to mount
        linked_service_name - linked service whose credentials are used for the mount
        mount_point_name - mount point to create or reuse
    '''
    current_mounts = mssparkutils.fs.mounts()
    existing_mount = next((m for m in current_mounts if m.mountPoint == mount_point_name), None)

    mount_abfss_path = existing_mount.source if existing_mount is not None else ""
    mount_linked_service = existing_mount.linkedService if existing_mount is not None else ""

    # Remount only when the mount is missing or its settings have changed
    if mount_abfss_path != abfss_path or mount_linked_service != linked_service_name:
        # Drop any stale mount at this point first; failure is expected when nothing is mounted
        try:
            mssparkutils.fs.unmount(mount_point_name)
        except Exception as e:
            logger.info(f"Unmount of {mount_point_name} skipped: {e}")

        mssparkutils.fs.mount(
            abfss_path,
            mount_point_name,
            # Authenticate with the linked service credentials
            {"linkedService": linked_service_name, "scope": "workspace"}
        )

        print(f"Mounted: {mount_point_name} for {abfss_path}. Scope: workspace")
        logger.info(f"Mounted: {mount_point_name} for {abfss_path}. Scope: workspace")
    else:
        print(f"Mount point {mount_point_name} EXIST for {abfss_path}. Scope: workspace")
        logger.info(f"Mount point {mount_point_name} EXIST for {abfss_path}. Scope: workspace")
        
def mount_from_linkedservice(linked_service_name, storage_container_name, scope = "workspace", env = None, max_retries = 3):
    '''
    Mount a storage container through a linked service and return its mount and synfs paths.

    Params:
        linked_service_name - linked service whose 'Endpoint' property identifies the storage account
        storage_container_name - container to mount
        scope - "workspace" reuses/refreshes a persistent mount via mount_workspace_point;
                any other value mounts with job scope
        env - environment segment of the mount point (passed to get_mount_name)
        max_retries - number of mount attempts before giving up (default 3)

    Returns:
        (mount_point, synfs_path) tuple

    Raises:
        ValueError - if max_retries < 1
        RuntimeError - if every mount attempt fails (chained to the last error)
    '''
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    endpoint = json.loads(mssparkutils.credentials.getPropertiesAll(linked_service_name))['Endpoint']
    storage_account_name = extract_storage_account_name(endpoint)
    abfss_path = f"abfss://{storage_container_name}@{storage_account_name}.dfs.core.windows.net"

    mount_point = get_mount_name(linked_service_name, storage_container_name, scope, env)

    # Job-scoped mounts are always recreated; unmount best effort (may not exist yet)
    if scope == "job":
        try:
            mssparkutils.fs.unmount(mount_point)
        except Exception:
            pass

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            if scope == "workspace":
                # Persistent mount; reused when source and linked service are unchanged
                mount_workspace_point(abfss_path, linked_service_name, mount_point)
            else:
                # Job-scoped mount using the linked service credentials
                mssparkutils.fs.mount(
                    abfss_path,
                    mount_point,
                    {"linkedService": linked_service_name, "scope": "job"}
                )
                print(f"Mounted: {mount_point} for {abfss_path}. Scope: {scope}")
                logger.info(f"Mounted: {mount_point} for {abfss_path}. Scope: {scope}")
            break
        except Exception as e:
            last_error = e
            print(f"Mount attempt {attempt}/{max_retries} failed for {mount_point}. Error: {e}")
            logger.warning(f"Mount attempt {attempt}/{max_retries} failed for {mount_point}. Error: {e}")
            # Clear a possibly half-created mount before trying again
            try:
                mssparkutils.fs.unmount(mount_point)
            except Exception:
                pass
            if attempt < max_retries:
                # Exponential backoff (1s, 30s, ...); no pointless wait after the last attempt
                time.sleep(30 ** (attempt - 1))
    else:
        # Loop finished without `break`: every attempt failed
        raise RuntimeError(
            f"Failed to mount {abfss_path} at {mount_point} after {max_retries} attempts"
        ) from last_error

    source_file_base_path = get_workspace_mount_synfs_path(mount_point, scope, env)

    # return mount and synfs path
    return mount_point, source_file_base_path

def get_workspace_mount_synfs_path(mount_point, scope = "job", env = None):
    '''
    Build the synfs:/ path under which a mount point is addressed.

    Params:
        mount_point - mount point path (e.g. from get_mount_name)
        scope - "workspace" gives synfs:/workspace<mount_point>; any other value is
                treated as a job-scoped mount whose path depends on the Spark version
        env - accepted for signature compatibility; not used

    Returns:
        synfs path string
    '''
    major_minor = ".".join(spark.version.split(".")[:2])
    current_job_id = mssparkutils.env.getJobId()

    if scope == "workspace":
        return f"synfs:/workspace{mount_point}"

    # Spark 3.3 places job mounts directly under the job id; every other version
    # uses an additional /notebook segment
    if major_minor == '3.3':
        return f"synfs:/{current_job_id}{mount_point}"
    return f"synfs:/notebook/{current_job_id}{mount_point}"

def get_mount_name(linked_service_name, storage_container_name, scope = "job", env = None):
    '''
    Return the mount point path "/<scope>/<env>/<linked_service_name>/<storage_container_name>/".

    When env is None the environment segment is "prod".
    '''
    env_segment = "prod" if env is None else env
    segments = [scope, env_segment, linked_service_name, storage_container_name]
    return "/" + "/".join(segments) + "/"

# Function to check if a Delta table exists
def delta_table_exists(tableName):
    '''
    Return True if tableName can be resolved and read through spark.table.

    A one-row read is attempted; any failure (missing table, access error, ...)
    is reported as "does not exist". The table format itself is not checked.
    '''
    try:
        spark.table(tableName).limit(1).collect()
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate
        return False
    return True

def get_latest_path_from_yyyy_mm_dd(file_path):
    '''
    Descend three folder levels below file_path, each time following the entry
    with the greatest name, and return the resulting path.

    NOTE(review): the greatest name is chosen by string comparison, so this picks
    the latest date only if the yyyy/mm/dd folder names are zero-padded — confirm.
    '''
    current_path = file_path
    for _level in ("yyyy", "mm", "dd"):
        children = mssparkutils.fs.ls(current_path)
        newest = max(children, key=lambda entry: entry.name)
        current_path = newest.path
    return current_path

#############################################################################################################

def create_table(metadata_json, dataframe, replace_where = None, table_name_suffix = None , skip_data_lineage = None, vacuum_table = None):

    '''
    Params:
        target_table - name of table if saving as managed table in Synapse or saving to ADLS (linked servie)

        target_schema - schema name if saving as managed table in Synapse

            this will mount connection to ADLS
        target_linked_service - Linked Service name where file file will be created
        target_path - path name where file file will be created
        target_container - container name where file file will be created

    '''
    #Read json metadata
    metadata = json.loads(metadata_json)

    #Get current notebook name
    notebook_name = mssparkutils.runtime.context['currentNotebookName']

    #Get evnironment
    env = None if metadata["env"]=="None" else metadata["env"]
    project_name = metadata["project_name"]
    
    #Overwrite project name with parameter from target_table
    target_project_name = metadata["target_table"].get("project_name",None)
    if target_project_name:
        project_name = target_project_name

    incremental_column_name = metadata.get("incremental_column_name",None)

    admin_schema_name = metadata.get("admin_schema_name","admin")
    admin_table_name = metadata.get("admin_table_name",None)

    partitionBy = metadata["target_table"].get("partitionBy",[])
    format_type = metadata["target_table"].get("format","delta")

    logger.info(f"ENV: {env}")
    print(f"ENV: {env}")
    logger.info(f"PROJECT_NAME: {project_name}")
    print(f"PROJECT_NAME: {project_name}")

    #Target table defined only once
    target_schema_name = metadata["target_table"].get("target_schema","") 
    target_table_name = metadata["target_table"].get("target_table","")

    if table_name_suffix:
        target_table_name = target_table_name + table_name_suffix

    target_ignore_env = metadata["target_table"].get("params",{}).get("ignore_env",False) 
    target_ignore_project = metadata["target_table"].get("params",{}).get("ignore_project",False) 
    target_drop_table = metadata["target_table"].get("params",{}).get("drop_table",False) 

    target_linked_service = metadata["target_table"].get("target_linked_service","") 
    target_path = metadata["target_table"].get("target_path","") 
    target_container = metadata["target_table"].get("target_container","") 

    synapse_target_table_name = get_table_env_name(target_schema_name, target_table_name, env, project_name, target_ignore_env, target_ignore_project)


    #Operation type
    # Support both correct "operation_type" and legacy typo "opertation_type"
    opertation_type = metadata["target_table"].get("operation_type", metadata["target_table"].get("opertation_type", None))

    print(opertation_type)
    logger.info(opertation_type)
    if partitionBy:
        print(f"Partition by: {partitionBy}")


    ####        ---------------------------------------------------------------------
    ####        ---------------------   "Managed Table Type"    ---------------------

    if opertation_type == "Managed Table Type":

        if target_drop_table:
            ## Dropping the table
            print(f"Drop table {synapse_target_table_name}")
            logger.info(f"Drop table {synapse_target_table_name}")
            spark.sql(
                f"DROP TABLE IF EXISTS {synapse_target_table_name}"
            )  # drop every time for now until everything will be checked (to avoid mergeschema)

        try:

            print(f"Create {format_type} table: {synapse_target_table_name}")
            logger.info(f"Create {format_type} table: {synapse_target_table_name}")
            if format_type == 'parquet':
                # Save as parquet
                dataframe.write.mode("overwrite").option("overwriteSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)
            elif format_type == 'delta':
                # Save as delta
                dataframe.write.format("delta").mode("overwrite").option("parquet.vorder.enabled ","true").option("overwriteSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)

        except AnalysisException as e:
            if "Can not create the managed table" in str(e):
                #match = re.search(r'abfss://[\w\.-]+/[\w\.-]+', str(e))
                #abfss_path = match.group(0)
                #print(abfss_path)
                
                #Regular expression to extract the ABFSS path
                pattern = r"The associated location\('([^']+)'\)"

                # Search for the pattern in the error message
                match = re.search(pattern, str(e))
                if match:
                    abfss_path = match.group(1)
                    print(f"Deleting files from ABFSS path: {abfss_path}")
                    mssparkutils.fs.rm(abfss_path, True)
                    print("Existing files deleted. You can now recreate the table.")
                else:
                    print("ABFSS path could not be extracted.")
                
            #Create if location exists but managed table does not
            print(f"Create {format_type} table: {synapse_target_table_name}")
            logger.info(f"Create {format_type} table: {synapse_target_table_name}")

            if format_type == 'parquet':
                # Save as parquet
                dataframe.write.mode("overwrite").option("overwriteSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)
            elif format_type == 'delta':
                # Save as delta
                dataframe.write.format("delta").mode("overwrite").option("overwriteSchema", "true").option("parquet.vorder.enabled ","true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)

        ## Validating if table is populated
        ## This would result to an error; halting the cell
        assert spark.sql(f"SELECT COUNT(*) as total_rows from {synapse_target_table_name}").collect()[0]['total_rows'] >= 0


    ####        -----------------------------------------------------------------------
    ####        ---------------------    "Linked Service Type"    ---------------------

    elif opertation_type == "Linked Service Type":

        mount_scope = "workspace"
        mount_point_name = get_mount_name(target_linked_service, target_container, mount_scope, env)

        mount_point[mount_point_name], mount_synfs_path[mount_point_name] = mount_from_linkedservice(target_linked_service, target_container, mount_scope, env)

        target_file_synfs_path = f"{mount_synfs_path[mount_point_name]}{target_path}/{target_table_name}"


        if format_type == 'parquet':
            # Save as parquet
            print(f"Save dataframe as parquet to: {target_file_synfs_path}")
            logger.info(f"Save dataframe as parquet to: {target_file_synfs_path}")
            dataframe.coalesce(1).write.mode("overwrite").option("overwriteSchema", "true").partitionBy(*partitionBy).parquet(target_file_synfs_path)
        elif format_type == 'delta':
            # Save as delta
            print(f"Save dataframe as delta to: {target_file_synfs_path}")
            logger.info(f"Save dataframe as delta to: {target_file_synfs_path}")
            dataframe.coalesce(1).write.format("delta").option("parquet.vorder.enabled ","true").mode("overwrite").option("overwriteSchema", "true").partitionBy(*partitionBy).save(target_file_synfs_path)
        elif format_type == 'csv':
            # Save as csv
            print(f"Save dataframe as csv to: {target_file_synfs_path}")
            logger.info(f"Save dataframe as csv to: {target_file_synfs_path}")
            csv_options = metadata["target_table"].get("csv_options",None) 
            count_rows = dataframe.count()
            if count_rows == 0:
                dataframe.coalesce(1).write.mode("overwrite").options(**csv_options).csv(target_file_synfs_path)
            else: 
                dataframe.coalesce(1).write.mode("overwrite").options(**csv_options).partitionBy(*partitionBy).csv(target_file_synfs_path)
        elif format_type == 'xlsx':
            # Save as excel
            print(f"Save dataframe as excel to: {target_file_synfs_path}")
            logger.info(f"Save dataframe as excel to: {target_file_synfs_path}")
            xlsx_options = metadata["target_table"].get("xlsx_options",None)
            spark.sparkContext.setLogLevel("ERROR")
            dataframe.write.format("com.crealytics.spark.excel").options(**xlsx_options).mode("overwrite").save(f"{target_file_synfs_path}.xlsx")
        elif format_type == 'big_xlsx':
            # Save as excel: Generate xlsx file using xlsxwriter with ZIP64 extensions enabled - Use this when the excel file is too big ~ 300k MiB - 400k+ MiB
            print(f"Save dataframe as excel to: {target_file_synfs_path}")
            logger.info(f"Save dataframe as excel to: {target_file_synfs_path}")
            xlsx_options = metadata["target_table"].get("xlsx_options",None)

            try:
                # Convert to pandas DataFrame
                import pandas as pd
                pandas_df = dataframe.toPandas().reset_index(drop=True)

                # Define the final destination path
                endpoint = json.loads(mssparkutils.credentials.getPropertiesAll(target_linked_service))['Endpoint']

                storage_account_name = extract_storage_account_name(endpoint)

                abfss_path = f"abfss://{target_container}@{storage_account_name}.dfs.core.windows.net{target_path}/{target_table_name}"

                excel_path =  f"{abfss_path}.xlsx"
                print(f"The excel path is: {excel_path}")

                with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
                    pandas_df.to_excel(writer, **xlsx_options)
                    # Enable ZIP64
                    writer.book.use_zip64()

            except Exception as e:
                print(f"Error during Excel conversion: {str(e)}")
                raise


    ####        ------------------------------------------------------------------------------------
    ####        ---------------------   "Managed Table Type - Merge by Key"    ---------------------

    elif opertation_type == "Managed Table Type - Merge by Key":   

        merge_condition = metadata["target_table"].get("merge_condition",None)

        target_table_exists = delta_table_exists(synapse_target_table_name)
    
        if not target_table_exists:
            #Create if location exists but managed table does not
            print(f"Create {format_type} table: {synapse_target_table_name}")
            logger.info(f"Create {format_type} table: {synapse_target_table_name}")
            dataframe.write.format("delta").option("parquet.vorder.enabled ","true").mode("overwrite").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)
            #Set table exists to true
            target_table_exists = True
            print("Table created") 
            logger.info("Table created") 

        else:
            target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)

            source_columns = set(dataframe.columns)
            target_columns = set(target_delta_table.toDF().columns)

            # Find columns that are in the source but not in the target
            new_columns = source_columns - target_columns
            # Find columns that are in the target but not in the source
            deleted_columns = target_columns - source_columns

            # Check if there are new columns
            if len(new_columns)>0:
                print(f"New columns found: {new_columns}")
                logger.info(f"New columns found: {new_columns}")
                #Just add new columns without any data
                dataframe.limit(0)\
                                .write\
                                .format("delta")\
                                .option("mergeSchema", "true")\
                                .mode("append")\
                                .saveAsTable(synapse_target_table_name)
                target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)

            # Check if there are new columns
            if len(deleted_columns)>0:
                
                df_drop_columns = target_delta_table.toDF()
                # Remove columns from Delta table
                for col_name in deleted_columns:
                    df_drop_columns = df_drop_columns.drop(col(col_name))
                    print(f"Remove columns: {col_name}")
                    logger.info(f"Remove columns: {col_name}")

                #Overwrite delta table with new schema
                df_drop_columns\
                            .write\
                            .format("delta")\
                            .mode("overwrite")\
                            .option("overwriteSchema", "true")\
                            .saveAsTable(synapse_target_table_name)
                target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)
                print(target_delta_table.toDF().columns)
                logger.info(target_delta_table.toDF().columns)

            print(f"Merge to table: {synapse_target_table_name}")
            logger.info(f"Merge to table: {synapse_target_table_name}")

            #Merge to target file
            target_delta_table.alias("target").merge(
                dataframe.alias('source'),
                merge_condition
                ).whenNotMatchedInsertAll().whenMatchedUpdateAll().execute()


    ####        ---------------------------------------------------------------------
    ####        ---------------------   "Managed Table Type - Append"    ------------

    if opertation_type == "Managed Table Type - Append":

        try:

            print(f"Append data to: {synapse_target_table_name} {format_type} table")
            logger.info(f"Append data to: {synapse_target_table_name} {format_type} table")
            if format_type == 'parquet':
                # Save as parquet
                dataframe.write.mode("append").option("mergeSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)
            elif format_type == 'delta':
                # Save as delta
                dataframe.write.format("delta").mode("append").option("mergeSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)

        except AnalysisException as e:
            if "Can not create the managed table" in str(e):
                #match = re.search(r'abfss://[\w\.-]+/[\w\.-]+', str(e))
                #abfss_path = match.group(0)
                #print(abfss_path)
                
                #Regular expression to extract the ABFSS path
                pattern = r"The associated location\('([^']+)'\)"

                # Search for the pattern in the error message
                match = re.search(pattern, str(e))
                if match:
                    abfss_path = match.group(1)
                    print(f"Deleting files from ABFSS path: {abfss_path}")
                    mssparkutils.fs.rm(abfss_path, True)
                    print("Existing files deleted. You can now recreate the table.")
                else:
                    print("ABFSS path could not be extracted.")
                
            #Create if location exists but managed table does not
            print(f"Append data to: {synapse_target_table_name} {format_type} table")
            logger.info(f"Append data to: {synapse_target_table_name} {format_type} table")

            if format_type == 'parquet':
                # Save as parquet
                dataframe.write.mode("append").option("mergeSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)
            elif format_type == 'delta':
                # Save as delta
                dataframe.write.format("delta").mode("append").option("mergeSchema", "true").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)

        if vacuum_table:
            spark.conf.set("spark.databricks.delta.retentionDurationCheck.enabled", "false")
            spark.sql(f"VACUUM {synapse_target_table_name} RETAIN 0 HOURS")

        ## Validating if table is populated
        ## This would result to an error; halting the cell
        assert spark.sql(f"SELECT COUNT(*) as total_rows from {synapse_target_table_name}").collect()[0]['total_rows'] >= 0


    elif opertation_type == "Managed Table Type - Repace where":

        target_table_exists = delta_table_exists(synapse_target_table_name)

        if not target_table_exists:
            
            #Create if location exists but managed table does not
            print(f"Create {format_type} table: {synapse_target_table_name}")
            logger.info(f"Create {format_type} table: {synapse_target_table_name}")
            dataframe.write.format("delta").option("parquet.vorder.enabled ","true").mode("overwrite").partitionBy(*partitionBy).saveAsTable(synapse_target_table_name)
            #Set table exists to true
            target_table_exists = True
            print("Table created") 
            logger.info("Table created")

        else:
            target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)

            source_columns = set(dataframe.columns)
            target_columns = set(target_delta_table.toDF().columns)

            # Find columns that are in the source but not in the target
            new_columns = source_columns - target_columns
            # Find columns that are in the target but not in the source
            deleted_columns = target_columns - source_columns

            # Check if there are new columns
            if len(new_columns)>0:
                print(f"New columns found: {new_columns}")
                logger.info(f"New columns found: {new_columns}")
                #Just add new columns without any data
                dataframe.limit(0)\
                                .write\
                                .format("delta")\
                                .option("mergeSchema", "true")\
                                .mode("append")\
                                .saveAsTable(synapse_target_table_name)
                target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)

            # Check if there are new columns
            if len(deleted_columns)>0:
                
                df_drop_columns = target_delta_table.toDF()
                # Remove columns from Delta table
                for col_name in deleted_columns:
                    df_drop_columns = df_drop_columns.drop(col(col_name))
                    print(f"Remove columns: {col_name}")
                    logger.info(f"Remove columns: {col_name}")

                #Overwrite delta table with new schema
                df_drop_columns\
                            .write\
                            .format("delta")\
                            .mode("overwrite")\
                            .option("overwriteSchema", "true")\
                            .saveAsTable(synapse_target_table_name)
                target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)
                print(target_delta_table.toDF().columns)
                logger.info(target_delta_table.toDF().columns)

            #Create if location exists but managed table does not
            print(f"Create {format_type} table: {synapse_target_table_name} replace where {replace_where}")
            logger.info(f"Create {format_type} table: {synapse_target_table_name} replace where {replace_where}")


            dataframe.write.format("delta").option("parquet.vorder.enabled ","true").mode("overwrite").option("replaceWhere", replace_where).partitionBy(*partitionBy).saveAsTable(finalTableName)

        ## Validating if table is populated
        ## This would result to an error; halting the cell
        assert spark.sql(f"SELECT COUNT(*) as total_rows from {synapse_target_table_name}").collect()[0]['total_rows'] >= 0


    ####        ---------------------------------------------------------------------
    ####        -----------------   "Incremental Table"    --------------------------

    elif opertation_type == "Incremental Table":
        """
        Incremental table operation - detects Insert/Update/Delete changes between source and target.
        Uses metadata structure with incremental_params:
        {
            "id_key_column": "unique_id",
            "included_columns_for_hash": [...],
            "excluded_columns_for_hash": [...],
            "log_history": true,
            "history_table_name": "...",
            "history_retention_days": 30,
            "ignore_new_columns_as_change": true
        }
        """

        # Extract incremental parameters from metadata
        incremental_params = metadata["target_table"].get("incremental_params", {})

        id_key_column = incremental_params.get("id_key_column")
        if not id_key_column:
            raise ValueError("Incremental Table operation requires 'id_key_column' in incremental_params")

        included_columns_for_hash = incremental_params.get("included_columns_for_hash", None)
        excluded_columns_for_hash = incremental_params.get("excluded_columns_for_hash", None)
        log_history = incremental_params.get("log_history", False)
        history_table_name = incremental_params.get("history_table_name", None)
        history_retention_days = incremental_params.get("history_retention_days", None)
        ignore_new_columns_as_change = incremental_params.get("ignore_new_columns_as_change", True)

        # Get additional create_table params
        create_table_params = metadata["target_table"].get("params", {})
        skip_data_lineage_param = create_table_params.get("skip_data_lineage", False)

        # Override with function parameter if provided
        if skip_data_lineage is not None:
            skip_data_lineage_param = skip_data_lineage

        print(f"Incremental Table operation for {synapse_target_table_name}")
        print(f"  - ID key column: {id_key_column}")
        print(f"  - Log history: {log_history}")
        if log_history and history_table_name:
            print(f"  - History table: {history_table_name}")

        logger.info(f"Incremental Table operation for {synapse_target_table_name}")
        logger.info(f"  ID key: {id_key_column}, log_history: {log_history}")

        # Check if target table exists
        target_exists = delta_table_exists(synapse_target_table_name)

        # ====================================================================================
        # PHASE 1: FULL LOAD (target doesn't exist)
        # ====================================================================================
        if not target_exists:
            print(f"Target table does not exist. Performing FULL LOAD...")
            logger.info(f"Target table does not exist. Performing FULL LOAD for {synapse_target_table_name}")

            # Add operation_type and last_update_dt columns for full load
            from pyspark.sql import functions as F

            df_full_load = dataframe\
                .withColumn("operation_type", F.lit("I"))\
                .withColumn("last_update_dt", F.current_timestamp())

            print(f"  - Creating target table with {df_full_load.count()} records")
            logger.info(f"Creating target table with {df_full_load.count()} records")

            # Create the table using existing create_table logic
            # Build metadata for full load
            metadata_for_full = metadata.copy()
            metadata_for_full["target_table"] = metadata["target_table"].copy()
            metadata_for_full["target_table"]["opertation_type"] = "Managed Table Type"
            metadata_for_full["target_table"]["operation_type"] = "Managed Table Type"

            # Use source_tables from metadata
            metadata_json_full = json.dumps(metadata_for_full)

            # Write using standard Managed Table Type
            df_full_load.write.format("delta")\
                .mode("overwrite")\
                .option("parquet.vorder.enabled", "true")\
                .option("overwriteSchema", "true")\
                .partitionBy(*partitionBy)\
                .saveAsTable(synapse_target_table_name)

            print(f"✓ Full load completed for {synapse_target_table_name}")
            logger.info(f"Full load completed for {synapse_target_table_name}")

            # Set return dataframe
            dataframe = df_full_load

        # ====================================================================================
        # PHASE 2: DELTA LOAD (target exists)
        # ====================================================================================
        else:
            print(f"Target table exists. Performing DELTA LOAD...")
            logger.info(f"Target table exists. Performing DELTA LOAD for {synapse_target_table_name}")

            from pyspark.sql import functions as F
            from pyspark.sql.window import Window
            from delta.tables import DeltaTable
            import hashlib

            # Load target table
            target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)
            df_target = target_delta_table.toDF()

            # Get source from dataframe (already created by make_env_tables)
            df_source = dataframe

            # -------------------------------------------------------------------------
            # STEP 1: Handle schema evolution - add new columns to target
            # -------------------------------------------------------------------------
            source_columns = set(df_source.columns)
            target_columns = set(df_target.columns)

            # Remove metadata columns from comparison
            metadata_columns = {"operation_type", "last_update_dt", "update_type"}
            source_columns_for_comparison = source_columns - metadata_columns
            target_columns_for_comparison = target_columns - metadata_columns

            new_columns = source_columns_for_comparison - target_columns_for_comparison

            if len(new_columns) > 0:
                print(f"  - New columns detected: {new_columns}")
                logger.info(f"New columns detected in source: {new_columns}")

                # Add new columns to target with NULL values
                df_source.limit(0)\
                    .write\
                    .format("delta")\
                    .option("mergeSchema", "true")\
                    .mode("append")\
                    .saveAsTable(synapse_target_table_name)

                # Reload target
                target_delta_table = DeltaTable.forName(spark, synapse_target_table_name)
                df_target = target_delta_table.toDF()

                print(f"  - New columns added to target table")
                logger.info(f"New columns added to target table")

            # -------------------------------------------------------------------------
            # STEP 2: Determine hash columns
            # -------------------------------------------------------------------------
            # Get all columns excluding metadata
            all_source_cols = [c for c in df_source.columns if c not in metadata_columns]

            # Determine columns to track for changes
            if included_columns_for_hash:
                tracked_columns = [c for c in included_columns_for_hash if c in all_source_cols]
            else:
                tracked_columns = all_source_cols.copy()

            # Apply exclusions
            if excluded_columns_for_hash:
                tracked_columns = [c for c in tracked_columns if c not in excluded_columns_for_hash]

            # If new columns should not trigger changes, remove them from tracked columns
            if ignore_new_columns_as_change and len(new_columns) > 0:
                tracked_columns = [c for c in tracked_columns if c not in new_columns]

            print(f"  - Tracking {len(tracked_columns)} columns for changes")
            logger.info(f"Tracking columns for changes: {tracked_columns}")

            # -------------------------------------------------------------------------
            # STEP 3: Create hashed DataFrames
            # -------------------------------------------------------------------------
            def create_hash_column(df, columns_to_hash, hash_col_name):
                """Create hash column from specified columns"""
                if not columns_to_hash:
                    return df.withColumn(hash_col_name, F.lit(None))

                # Concatenate all columns and create MD5 hash
                concat_expr = F.concat_ws("||", *[F.coalesce(F.col(c).cast("string"), F.lit("NULL")) for c in columns_to_hash])
                return df.withColumn(hash_col_name, F.md5(concat_expr))

            # Hash for source
            df_source_hashed = create_hash_column(df_source, tracked_columns, "_hash_tracked")
            df_source_hashed = create_hash_column(df_source_hashed, all_source_cols, "_hash_all")

            # Hash for target (only active records: I or U)
            df_target_active = df_target.filter(F.col("operation_type").isin(["I", "U"]))
            df_target_hashed = create_hash_column(df_target_active, tracked_columns, "_hash_tracked")
            df_target_hashed = create_hash_column(df_target_hashed, all_source_cols, "_hash_all")

            # -------------------------------------------------------------------------
            # STEP 4: Detect INSERTIONS (in source, not in target)
            # -------------------------------------------------------------------------
            df_insertions = df_source_hashed\
                .join(df_target_hashed, df_source_hashed[id_key_column] == df_target_hashed[id_key_column], "left_anti")\
                .select(df_source_hashed["*"])\
                .withColumn("operation_type", F.lit("I"))\
                .withColumn("last_update_dt", F.current_timestamp())\
                .withColumn("update_type", F.lit("Insert"))

            num_insertions = df_insertions.count()
            print(f"  - Insertions detected: {num_insertions}")
            logger.info(f"Insertions: {num_insertions}")

            # -------------------------------------------------------------------------
            # STEP 5: Detect DELETIONS (in target, not in source)
            # -------------------------------------------------------------------------
            df_deletions = df_target_hashed\
                .join(df_source_hashed, df_target_hashed[id_key_column] == df_source_hashed[id_key_column], "left_anti")\
                .select(df_target_hashed["*"])\
                .withColumn("operation_type", F.lit("D"))\
                .withColumn("update_type", F.lit("Delete"))
                # Keep original last_update_dt for deletions

            num_deletions = df_deletions.count()
            print(f"  - Deletions detected: {num_deletions}")
            logger.info(f"Deletions: {num_deletions}")

            # -------------------------------------------------------------------------
            # STEP 6: Detect UPDATES (tracked columns changed)
            # -------------------------------------------------------------------------
            df_updates_tracked = df_source_hashed.alias("src")\
                .join(df_target_hashed.alias("tgt"),
                      F.col("src." + id_key_column) == F.col("tgt." + id_key_column),
                      "inner")\
                .filter(F.col("src._hash_tracked") != F.col("tgt._hash_tracked"))\
                .select("src.*")\
                .withColumn("operation_type", F.lit("U"))\
                .withColumn("last_update_dt", F.current_timestamp())\
                .withColumn("update_type", F.lit("Update-tracked"))

            num_updates_tracked = df_updates_tracked.count()
            print(f"  - Updates (tracked columns): {num_updates_tracked}")
            logger.info(f"Updates (tracked): {num_updates_tracked}")

            # -------------------------------------------------------------------------
            # STEP 7: Detect UPDATES (only untracked columns changed)
            # -------------------------------------------------------------------------
            df_updates_untracked = df_source_hashed.alias("src")\
                .join(df_target_hashed.alias("tgt"),
                      F.col("src." + id_key_column) == F.col("tgt." + id_key_column),
                      "inner")\
                .filter(
                    (F.col("src._hash_tracked") == F.col("tgt._hash_tracked")) &
                    (F.col("src._hash_all") != F.col("tgt._hash_all"))
                )\
                .select([F.col("src." + c).alias(c) for c in df_source_hashed.columns])\
                .withColumn("operation_type", F.lit("U"))\
                .withColumn("last_update_dt", F.col("tgt.last_update_dt"))\  # Keep original timestamp
                .withColumn("update_type", F.lit("Update-untracked"))

            num_updates_untracked = df_updates_untracked.count()
            print(f"  - Updates (untracked columns only): {num_updates_untracked}")
            logger.info(f"Updates (untracked): {num_updates_untracked}")

            # -------------------------------------------------------------------------
            # STEP 8: Detect REACTIVATIONS (previously deleted records returning)
            # -------------------------------------------------------------------------
            df_target_deleted = df_target.filter(F.col("operation_type") == "D")

            df_reactivations = df_source_hashed.alias("src")\
                .join(df_target_deleted.alias("tgt"),
                      F.col("src." + id_key_column) == F.col("tgt." + id_key_column),
                      "inner")\
                .select("src.*")\
                .withColumn("operation_type", F.lit("U"))\
                .withColumn("last_update_dt", F.current_timestamp())\
                .withColumn("update_type", F.lit("Reactivate"))

            num_reactivations = df_reactivations.count()
            if num_reactivations > 0:
                print(f"  - Reactivations detected: {num_reactivations}")
                logger.info(f"Reactivations: {num_reactivations}")

            # -------------------------------------------------------------------------
            # STEP 9: Combine all changes
            # -------------------------------------------------------------------------
            df_all_changes = df_insertions\
                .union(df_updates_tracked)\
                .union(df_updates_untracked)\
                .union(df_reactivations)\
                .union(df_deletions)

            # Drop hash columns
            df_all_changes = df_all_changes.drop("_hash_tracked", "_hash_all")

            total_changes = df_all_changes.count()
            print(f"  - Total changes to apply: {total_changes}")
            logger.info(f"Total changes: {total_changes}")

            if total_changes == 0:
                print(f"✓ No changes detected. Target table is up to date.")
                logger.info(f"No changes detected for {synapse_target_table_name}")
                dataframe = None
            else:
                # -------------------------------------------------------------------------
                # STEP 10: Log to history table (optional)
                # -------------------------------------------------------------------------
                if log_history and history_table_name:
                    print(f"  - Logging changes to history table: {history_table_name}")
                    logger.info(f"Logging to history: {history_table_name}")

                    # Get fully qualified history table name
                    history_schema = metadata["target_table"].get("target_schema", target_schema_name)
                    history_table_full = get_table_env_name(history_schema, history_table_name, env, project_name, target_ignore_env, target_ignore_project)

                    # Add audit columns
                    df_history = df_all_changes\
                        .withColumn("_audit_timestamp", F.current_timestamp())\
                        .withColumn("_audit_operation", F.col("update_type"))

                    # Append to history table
                    df_history.write\
                        .format("delta")\
                        .mode("append")\
                        .option("mergeSchema", "true")\
                        .saveAsTable(history_table_full)

                    print(f"  ✓ Logged {total_changes} changes to history")
                    logger.info(f"Logged {total_changes} changes to history table")

                    # Apply retention policy if specified
                    if history_retention_days:
                        retention_date = F.current_date() - F.expr(f"INTERVAL {history_retention_days} DAYS")
                        history_delta = DeltaTable.forName(spark, history_table_full)
                        history_delta.delete(F.col("_audit_timestamp") < retention_date)
                        print(f"  ✓ Applied retention policy: {history_retention_days} days")
                        logger.info(f"Applied retention policy: {history_retention_days} days")

                # -------------------------------------------------------------------------
                # STEP 11: Merge changes to target table
                # -------------------------------------------------------------------------
                print(f"  - Merging changes to target table...")
                logger.info(f"Merging {total_changes} changes to {synapse_target_table_name}")

                # Prepare final DataFrame for merge (drop update_type helper column)
                df_for_merge = df_all_changes.drop("update_type")

                # Create merge condition
                merge_condition = f"target.{id_key_column} = source.{id_key_column}"

                # Perform merge
                target_delta_table.alias("target")\
                    .merge(df_for_merge.alias("source"), merge_condition)\
                    .whenMatchedUpdateAll()\
                    .whenNotMatchedInsertAll()\
                    .execute()

                print(f"✓ Delta load completed for {synapse_target_table_name}")
                print(f"  Summary: {num_insertions} inserts, {num_updates_tracked + num_updates_untracked} updates, {num_deletions} deletes, {num_reactivations} reactivations")
                logger.info(f"Delta load completed: {num_insertions}I, {num_updates_tracked + num_updates_untracked}U, {num_deletions}D, {num_reactivations}R")

                # Set return dataframe
                dataframe = df_all_changes

        # Validate table is populated
        row_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {synapse_target_table_name}").collect()[0]['cnt']
        print(f"  - Final row count: {row_count}")
        logger.info(f"Final row count in {synapse_target_table_name}: {row_count}")
        assert row_count >= 0


    #If incremetnal load then add row to metadata 
    if incremental_column_name:
        ## Log to metadata max of processing data from target table (column incremental_column_name)

        processing_date_to_log = get_max_of_incremental_column_name(metadata_json)

        log_data = get_log_data(metadata_json, processing_date_to_log)

        log_to_metadata(log_data,metadata_json,admin_schema_name, admin_table_name)

    if not skip_data_lineage:
        log_to_datalineage(metadata_json, admin_schema_name, data_lineage_table_name)


#############################################################################################################

def make_env_tables(metadata_json):
    """
    Register one Spark temp view per entry of ``metadata["source_tables"]`` and resolve the
    fully-qualified target table name, honouring environment / project customisation.

    Each source entry may override the global environment (``env``) or project name
    (``project_name``) through ``params.ignore_env`` / ``params.ignore_project``. This is useful
    for tables that do not follow the standard naming scheme. A source flagged with
    ``params.optional`` is skipped (instead of failing the run) when it is missing or cannot be
    loaded.

    Supported ``opertation_type`` values per source (note: the metadata key is spelled
    "opertation_type"):
      - "lake database view" (default): ``spark.table`` on the env/project-resolved name.
      - "Linked Service": files read through a linked-service mount (parquet / delta / csv).
      - "delta sharing": Spark ``deltaSharing`` reader with a profile fetched from Key Vault.
      - "delta sharing 1.2.0": ``delta_sharing.load_as_pandas`` converted to a Spark DataFrame.
      - "serverless jdbc ls" / "serverless jdbc ls sqlauth": JDBC reads through a linked service.

    Row filtering per source:
      - ``load_type == "incremental - by last target table processing date"``: keeps rows whose
        ``incremental_column_name`` (default "LOAD_DATE") is newer than the last logged
        ``modify_datetime`` in the admin table. An earlier ``process_date`` from the metadata
        overrides it. Set ``ignore_incremental_load = "Y"`` to load everything.
      - otherwise the optional ``sql_where`` string is applied as a filter.

    Parameters:
    - metadata_json (str): JSON string describing the target table, the source tables,
      the environment (``env``; the string "None" means no environment) and ``project_name``.

    Returns:
    - dict with the structure::

        {
            "target_table": "<resolved target>",
            "source_tables": { "<view_name>": "<resolved source ref>", ... },
            "skipped_optional_views": { "<view_name>": "<source ref or None>", ... }
        }

      ``source_tables`` only contains views that were created and had a source reference.
      ``skipped_optional_views`` maps each skipped optional view to its source reference
      (``None`` if the failure happened before the reference was known).

    Raises:
    - ValueError: unsupported ``opertation_type`` or missing delta-sharing parameters
      (only for non-optional sources).
    - Any other exception raised while reading a non-optional source is re-raised.
    """

    # Result returned to the caller (see docstring for structure).
    result = {"skipped_optional_views": {}, "source_tables": {}}

    # Read json metadata
    metadata = json.loads(metadata_json)

    # Global project name and environment ("None" string in metadata means "no environment")
    project_name = metadata["project_name"]
    env = None if metadata["env"] == "None" else metadata["env"]

    target_params = metadata["target_table"].get("params") or {}
    target_ignore_env = target_params.get("ignore_env", False)
    target_ignore_project = target_params.get("ignore_project", False)

    # Admin (metadata log) table; queried for incremental loads
    final_admin_table_name = get_table_env_name(metadata.get("admin_schema_name", ""), metadata.get("admin_table_name", ""), env, project_name, False, ignore_project=True)

    target_schema_name = metadata["target_table"].get("target_schema", "")
    target_table_name = metadata["target_table"].get("target_table", "")
    target_opertation_type = metadata["target_table"].get("opertation_type", None)

    if target_opertation_type == "Linked Service Type":
        target_linked_service = metadata["target_table"].get("target_linked_service", "")
        target_path = metadata["target_table"].get("target_path", "")
        target_container = metadata["target_table"].get("target_container", "")

        synapse_target_table_name = f"{target_container}/{target_linked_service}/{target_path}/"
    else:
        synapse_target_table_name = get_table_env_name(target_schema_name, target_table_name, env, project_name, target_ignore_env, target_ignore_project)

    # Multiple source tables
    for source_table in metadata["source_tables"]:

        linked_service_name = source_table.get("linked_service_name", None)

        # Per-source state; reset every iteration so nothing (e.g. a DataFrame or a
        # source reference) leaks from the previous source into this one.
        df = None
        sql_where = None
        skip = False
        source_ref = None

        # "view_name" is preferred; "view" is accepted as an alias.
        temp_view_name = source_table.get("view_name", None)
        if temp_view_name is None:
            temp_view_name = source_table.get("view", None)

        # Fall back to the global project when the source does not declare its own.
        temp_view_project_name = source_table.get("project_name", None)
        source_project_name = temp_view_project_name if temp_view_project_name else project_name

        source_params = source_table.get("params") or {}
        source_ignore_env = source_params.get("ignore_env", False)
        source_ignore_project = source_params.get("ignore_project", False)
        source_optional = source_params.get("optional", False)

        opertation_type = source_table.get("opertation_type", "lake database view")

        try:
            ####        ---------------------------------------------------------------------
            ####        ---------------------   "Linked Service"    -------------------------

            if opertation_type == "Linked Service":
                # Initialize the mount point (stored in the global mount dictionaries)
                container = source_table.get("container", None)
                file_format = source_table.get("file_format", "parquet")
                mount_scope = source_table.get("mount_scope", "workspace")

                mount_point_name = get_mount_name(linked_service_name, container, mount_scope, env)

                mount_point[mount_point_name], mount_synfs_path[mount_point_name] = mount_from_linkedservice(linked_service_name, container, mount_scope, env)

                path = source_table.get("path", None)
                target_file_synfs_path = f"{mount_synfs_path[mount_point_name]}{path}/"

                if source_table.get("load_type", None) == "lastest by YYYY/MM/DD":
                    target_file_synfs_path = get_latest_path_from_yyyy_mm_dd(target_file_synfs_path)

                if source_table.get("path_suffix", None):
                    target_file_synfs_path = target_file_synfs_path + "/" + source_table.get("path_suffix", "")

                print("Create temp view: " + temp_view_name + " as linked service path: " + target_file_synfs_path)
                logger.info("Create temp view: " + temp_view_name + " as linked service path: " + target_file_synfs_path)
                source_ref = target_file_synfs_path

                if source_optional and not mssparkutils.fs.exists(target_file_synfs_path):
                    skip = True
                elif file_format == "parquet":
                    df = spark.read.option("ignoreCorruptFiles", "true").parquet(target_file_synfs_path)
                elif file_format == "delta":
                    df = spark.read.format("delta").load(target_file_synfs_path)
                elif file_format == "csv":
                    # csv_options is optional; an absent value means "reader defaults"
                    csv_options = source_table.get("csv_options") or {}
                    df = spark.read.format("csv").options(**csv_options).load(target_file_synfs_path)
                else:
                    raise Exception("File format not supported")

            ####        ---------------------------------------------------------------------
            ####        -----------   "delta sharing protocol - spark (JAR)"    -------------

            elif opertation_type == "delta sharing":
                # params: key vault linked service name, secret holding the delta sharing
                # profile (json), and the table name <share name>.<schema name>.<table name>
                table_name = source_table.get("table", None)
                key_vaule_ls = source_table.get("key_vaule_ls", None)
                secret_with_profile = source_table.get("profile", None)

                try:
                    profile = mssparkutils.credentials.getSecretWithLS(key_vaule_ls, secret_with_profile)
                    table_url = profile + "#" + table_name
                except Exception:
                    print(f"An error occurred in detla sharing table adress {table_name}")
                    logger.info(f"An error occurred in detla sharing table adress {table_name}")
                    raise

                print("Create temp view: " + temp_view_name + " from delta sharing table: " + table_name)
                logger.info("Create temp view: " + temp_view_name + " from delta sharing table: " + table_name)

                df = spark.read.format("deltaSharing").load(table_url)
                source_ref = table_url

            elif opertation_type == "delta sharing 1.2.0":
                # Same parameters as "delta sharing", read through the python delta_sharing client
                table_name = source_table.get("table", None)
                key_vaule_ls = source_table.get("key_vaule_ls", None)
                secret_with_profile = source_table.get("profile", None)

                if table_name is None:
                    raise ValueError("table cannot be None")
                if key_vaule_ls is None:
                    raise ValueError("key_vaule_ls cannot be None")
                if secret_with_profile is None:
                    raise ValueError("profile cannot be None")

                table_url = f"{key_vaule_ls};{secret_with_profile}#{table_name}"

                print("Create temp view: " + temp_view_name + " from delta sharing table: " + table_name)
                logger.info("Create temp view: " + temp_view_name + " from delta sharing table: " + table_name)

                df = spark.createDataFrame(delta_sharing.load_as_pandas(table_url))
                source_ref = table_url

            ####        ---------------------------------------------------------------------
            ####        ---------------------   "serverless jdbc ls"    ---------------------

            elif opertation_type == "serverless jdbc ls":
                # params: linked_service_name, view, table name
                table_name = source_table.get("table", None)

                print("Create temp view: " + temp_view_name + " as JDBC linked service " + linked_service_name + " from table: " + table_name)
                logger.info("Create temp view: " + temp_view_name + " as JDBC linked service " + linked_service_name + " from table: " + table_name)
                df = getJDBCdataWithLinkedService(linked_service_name, table_name)
                source_ref = table_name

            elif opertation_type == "serverless jdbc ls sqlauth":
                # params: linked_service_name, view, table name
                table_name = source_table.get("table", None)

                print("Create temp view: " + temp_view_name + " as JDBC linked service " + linked_service_name + " from table: " + table_name)
                logger.info("Create temp view: " + temp_view_name + " as JDBC linked service " + linked_service_name + " from table: " + table_name)
                df = getJDBCdataWithLinkedServiceSQLAuth(linked_service_name, table_name)
                # Previously referenced `table_url`, which is not defined for JDBC sources.
                source_ref = table_name

            ####        ------------------------------Default--------------------------------
            ####        ---------------------   "lake database view"    ---------------------

            elif opertation_type == "lake database view":
                source_schema_name = source_table.get("schema", None)
                source_table_name = source_table.get("table", None)

                synapse_source_table_name = get_table_env_name(source_schema_name, source_table_name, env, source_project_name, source_ignore_env, source_ignore_project)

                print("Create temp view: " + temp_view_name + " from table: " + synapse_source_table_name)
                logger.info("Create temp view: " + temp_view_name + " from table: " + synapse_source_table_name)
                source_ref = synapse_source_table_name

                # If optional, skip cleanly when table is missing
                if source_optional and not delta_table_exists(synapse_source_table_name):
                    skip = True
                else:
                    df = spark.table(synapse_source_table_name)

            else:
                # Without this, an unknown type would silently reuse an undefined/stale DataFrame.
                raise ValueError(f"Unsupported opertation_type: {opertation_type!r}")

            # DEFINE TEMP VIEW

            if skip:
                print(f"Skip optional view {temp_view_name} : table/path not found")
                logger.info(f"Skip optional view {temp_view_name} : table/path not found")
                result["skipped_optional_views"][temp_view_name] = source_ref
                continue

            # Get last processing date and filter source data
            if source_table.get("load_type", None) == "incremental - by last target table processing date":

                # Name under which this source is logged in the admin table
                if source_table.get("opertation_type", None) == "Linked Service":
                    container = source_table.get("container", None)
                    path = source_table.get("path", None)
                    synapse_source_table_name = f"{container}/{linked_service_name}/{path}/"
                else:
                    # JDBC, delta sharing and lake database sources: resolve from THIS entry's
                    # schema/table (the JDBC branches previously read names that were only
                    # defined by an earlier lake-database iteration, if at all).
                    source_schema_name = source_table.get("schema", None)
                    source_table_name = source_table.get("table", None)
                    synapse_source_table_name = get_table_env_name(source_schema_name, source_table_name, env, source_project_name, source_ignore_env, source_ignore_project)

                source_ignore_incremental_load = source_table.get("ignore_incremental_load", "N")

                if source_ignore_incremental_load == "N":
                    # Incremental load
                    process_date = metadata.get("process_date", "1900-01-01")
                    incremental_column_name = source_table.get("incremental_column_name", "LOAD_DATE")

                    DeltaTable.forName(spark, final_admin_table_name).toDF().createOrReplaceTempView("metadata_temp_view")
                    metadata_df = spark.sql(f"""
                        SELECT  COALESCE(MAX(modify_datetime),'1900-01-01') modify_datetime
                        FROM    metadata_temp_view
                        WHERE   lower(synapse_target_table_name) = lower('{synapse_target_table_name}')
                        AND     lower(synapse_source_table_name) = lower('{synapse_source_table_name}')
                        AND     load_type = 'incremental - by last target table processing date'
                    """)
                    metadata_modify_datetime = metadata_df.collect()[0]['modify_datetime']

                    # An explicit earlier process_date wins over the last logged processing date.
                    if process_date < metadata_modify_datetime:
                        print(f"Override processing date with {process_date} value")
                        logger.info(f"Override processing date with {process_date} value")
                        query_processing_datetime = process_date
                    else:
                        query_processing_datetime = metadata_modify_datetime

                    sql_where = f"`{incremental_column_name}` > '{query_processing_datetime}'"

                elif source_ignore_incremental_load == "Y":
                    # Ignore incremental load
                    print(f"Full load -> ignore incremental load")
                    logger.info(f"Full load -> ignore incremental load")

            else:
                sql_where = source_table.get("sql_where", None)

            if sql_where:
                print(f"Apply sql_where condition: {sql_where}")
                logger.info(f"Apply sql_where condition: {sql_where}")
                df = df.filter(sql_where)

            print(f"Create view {temp_view_name}")
            logger.info(f"Create view {temp_view_name}")
            # Create view
            df.createOrReplaceTempView(temp_view_name)
            if source_ref is not None:
                result["source_tables"][temp_view_name] = source_ref

        except Exception as e:
            msg = str(e)
            if source_optional:
                print(f"Skip optional view {temp_view_name} ; Error while making table {source_table} : {msg}")
                logger.info(f"Skip optional view {temp_view_name} ; Error while making table {source_table} : {msg}")
                # skipped_optional_views is a dict (view -> source ref); `.append` raised AttributeError.
                result["skipped_optional_views"][temp_view_name] = source_ref
                continue
            print(f"Error while making table={source_table}, error: {msg}")
            logger.error(f"Error while making table={source_table}, error: {msg}")
            raise

    result["target_table"] = synapse_target_table_name
    return result
    
#############################################################################################################

def log_to_metadata(filtered_rows:list, metadata_json, admin_schema_name = None, admin_table_name = None, max_retries = 5):
    """
    Upsert processing-log rows into the admin metadata Delta table.

    Incoming rows are matched to existing metadata rows on ``synapse_source_table_name`` and
    ``params``, restricted to metadata rows whose ``synapse_target_table_name`` equals this job's
    target. Matched rows are updated and unmatched rows are inserted. When the incoming
    ``modify_datetime`` is ``1900-01-01`` the stored ``modify_datetime`` is kept and only the
    other columns are updated.

    Parameters:
    filtered_rows (list): Rows matching the metadata schema defined below.
    metadata_json (str): Job metadata JSON (project_name, env, target_table and, when the admin
        arguments are not given, admin_schema_name / admin_table_name).
    admin_schema_name (str, optional): Schema of the metadata table; read from metadata if falsy.
    admin_table_name (str, optional): Metadata table name; read from metadata if falsy.
    max_retries (int): Maximum merge attempts when a ``ConcurrentAppendException`` occurs.

    Merge errors are logged and not raised (best-effort logging). When all attempts fail with
    write conflicts, an explicit failure message is logged.
    """
    metadata = json.loads(metadata_json)
    project_name = metadata["project_name"]

    # Target table is defined only once per metadata document
    target_table = metadata["target_table"]
    target_schema_name = target_table.get("target_schema","")
    target_table_name = target_table.get("target_table","")

    env = None if metadata["env"]=="None" else metadata["env"]

    if not admin_schema_name:
        admin_schema_name = metadata["admin_schema_name"]
    if not admin_table_name:
        admin_table_name = metadata["admin_table_name"]

    # Admin tables are environment specific but never project prefixed
    finalAdminTableName = get_table_env_name(admin_schema_name, admin_table_name, env, project_name, False, True)

    target_opertation_type = target_table.get("opertation_type",None)

    if target_opertation_type == "Linked Service Type":
        target_linked_service = target_table.get("target_linked_service","")
        target_path = target_table.get("target_path","")
        target_container = target_table.get("target_container","")
        synapse_target_table_name = f"{target_container}/{target_linked_service}/{target_path}/"
    else:
        synapse_target_table_name = get_table_env_name(target_schema_name, target_table_name, env, project_name, False, False)

    schema = StructType([
        StructField("project_name", StringType(), False, {'comment': "Name of project"}),
        StructField("load_type", StringType(), True, {'comment': "delta/full"}),
        StructField("synapse_source_table_name", StringType(), True, {'comment': "Name of source table in Synapse"}),
        StructField("synapse_target_table_name", StringType(), True, {'comment': "Name of target table in Synapse"}),
        StructField("params", StringType(), True, {'comment': "Partitions, merge conditions"}),
        StructField("modify_datetime", TimestampType(), True, {'comment': "Source file modification date"}),
        StructField("processing_date", DateType(), False, {'comment': "File processing date"})
    ])

    merge_to_metadata_df = spark.createDataFrame(filtered_rows, schema)

    delta_metadata_table = DeltaTable.forName(spark, finalAdminTableName)

    merge_condition = (
        "metadata.synapse_source_table_name = processed.synapse_source_table_name"
        " and metadata.params = processed.params"
        f" and metadata.synapse_target_table_name = '{synapse_target_table_name}'"
    )

    update_set = {
        column: f"processed.{column}"
        for column in (
            "project_name",
            "load_type",
            "synapse_source_table_name",
            "synapse_target_table_name",
            "params",
            "modify_datetime",
            "processing_date",
        )
    }
    # modify_datetime is not updated when the incoming value is 1900-01-01
    update_set_keep_modify = {k: v for k, v in update_set.items() if k != "modify_datetime"}

    # Write processing info to the metadata table, retrying on concurrent-append conflicts
    for attempt in range(max_retries):
        try:
            delta_metadata_table.alias("metadata").merge(
                merge_to_metadata_df.alias("processed"),
                merge_condition
            ).whenMatchedUpdate(
                condition = "processed.modify_datetime != '1900-01-01'",
                set = update_set
            ).whenMatchedUpdate(
                condition = "processed.modify_datetime = '1900-01-01'",
                set = update_set_keep_modify
            ).whenNotMatchedInsertAll().execute()
        except ConcurrentAppendException as e:
            print(f"Write conflict detected on attempt {attempt+1}. Error: {e}. Retrying...")
            logger.warning(f"Write conflict detected on attempt {attempt+1}. Error: {e}")
            # No point waiting after the final attempt
            if attempt + 1 < max_retries:
                time.sleep(2 ** attempt)  # Exponential backoff: 1s, 2s, 4s, ...
            continue
        except Exception as e:
            print(f"An error occurred: {e}")
            logger.error(f"An error occurred: {e}")
            return

        print(f"Write metadata successful to {finalAdminTableName}")
        logger.info(f"Write metadata successful to {finalAdminTableName}")
        return

    # Reached only when every attempt hit a write conflict (or max_retries <= 0)
    print(f"Write metadata to {finalAdminTableName} failed after {max_retries} attempts")
    logger.error(f"Write metadata to {finalAdminTableName} failed after {max_retries} attempts")

#############################################################################################################

def create_sql_condition(json_strings):
    """
    Build a SQL WHERE fragment from JSON documents that each contain a 'partitions' dictionary.

    Every key found under 'partitions' (across all documents) becomes one
    ``\`key\` IN ('v1','v2',...)`` predicate listing the distinct values seen for that key.
    Predicates for different keys are combined with ``AND``. Values are emitted in sorted
    (string) order so the resulting condition is deterministic. Values are quoted as-is
    (no escaping of embedded quotes).

    Parameters:
    - json_strings (iterable of str): JSON documents of the form '{"partitions": {...}}'.

    Returns:
    - str: The condition, or an empty string when ``json_strings`` is empty.

    Raises:
    - json.JSONDecodeError: If any string is not valid JSON (message includes the offending string).
    - KeyError: If a document has no 'partitions' key.

    Example:
    >>> create_sql_condition(['{"partitions": {"year": "2023"}}', '{"partitions": {"year": "2022"}}'])
    "`year` IN ('2022','2023')"
    """

    unique_values_by_key = {}

    for json_str in json_strings:
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            # JSONDecodeError requires (msg, doc, pos); keep the original position information
            raise json.JSONDecodeError(f"Error deserializing JSON: {json_str}", e.doc, e.pos) from e
        for key, value in data["partitions"].items():
            unique_values_by_key.setdefault(key, set()).add(value)

    # One IN predicate per key; sort values because set iteration order is not stable between runs
    sql_parts = []
    for key, values in unique_values_by_key.items():
        values_str = ",".join(f"'{v}'" for v in sorted(values, key=str))
        sql_parts.append(f"`{key}` IN ({values_str})")

    # Predicates for different keys must all hold
    return " AND ".join(sql_parts)


#############################################################################################################

def partitions_to_process(metadata_json, admin_schema_name, admin_table_name_to_process, admin_table_name_to_raw):
    """
    Return the raw partitions that have not been processed yet for this job's target table.

    The admin "raw" and "process" tables are registered as temp views ``admin_raw`` and
    ``admin_process``. For every source table in the metadata, rows of ``admin_raw`` for that
    source are selected when no ``admin_process`` row exists with the same project_name,
    load_type, source name, target name, params and processing_date (LEFT JOIN ... IS NULL).

    Parameters:
    metadata_json (str): Job metadata JSON (env, project_name, target_table, source_tables).
    admin_schema_name (str): Base schema of the admin tables (env suffix is appended).
    admin_table_name_to_process (str): Admin table holding already processed partitions.
    admin_table_name_to_raw (str): Admin table holding raw (landed) partitions.

    Returns:
    list: Collected Spark Rows sorted by ``params`` (rows with NULL params last).
    """
    # Metadata store
    metadata = json.loads(metadata_json)

    env = None if metadata["env"]=="None" else metadata["env"]
    project_name = metadata["project_name"]

    target_schema_name = project_name+'_'+metadata["target_table"].get("target_schema","")
    target_table_name = metadata["target_table"].get("target_table","")

    ## Formulating the admin and target table names
    if env:
        admin_schema_name = admin_schema_name +"_"+ env
        target_schema_name = target_schema_name +"_"+ env
    final_admin_table_name_to_process = f"{admin_schema_name}.{admin_table_name_to_process}"
    final_admin_table_name_to_raw = f"{admin_schema_name}.{admin_table_name_to_raw}"
    finalTableName = f"{target_schema_name}.{target_table_name}"

    print("Get partitions to process...")
    logger.info("Get partitions to process...")

    spark.table(final_admin_table_name_to_process).createOrReplaceTempView("admin_process")
    spark.table(final_admin_table_name_to_raw).createOrReplaceTempView("admin_raw")

    schema = StructType([
        StructField("project_name", StringType(), False, {'comment': "Name of project"}),
        StructField("load_type", StringType(), True, {'comment': "delta/full"}),
        StructField("synapse_source_table_name", StringType(), True, {'comment': "Name of source table in Synapse"}),
        StructField("synapse_target_table_name", StringType(), True, {'comment': "Name of target table in Synapse"}),
        StructField("params", StringType(), True, {'comment': "Partitions, merge conditions"}),
        StructField("modify_datetime", DateType(), True, {'comment': "Source file modification date"}),
        StructField("processing_date", DateType(), False, {'comment': "File processing date"})
    ])
    union_df = spark.createDataFrame([], schema)

    # Multiple source tables
    for source_table in metadata["source_tables"]:

        schema_source_name = source_table.get("schema",None)
        table_source_name = source_table.get("table",None)

        # Per-source overrides are computed locally so they do not leak into later source tables
        source_project_name = source_table.get("project_name",None) or project_name
        param = source_table.get("params",None)
        ignore_env = bool(param and param.get("ignore_env",False))
        ignore_project = bool(param and param.get("ignore_project",False))

        finalSourceTableName = get_table_env_name(schema_source_name, table_source_name, env, source_project_name, ignore_env, ignore_project)

        partition_to_process_df = spark.sql(f"""
                SELECT       r.project_name
                            ,r.load_type
                            ,lower(r.synapse_name) as synapse_source_table_name
                            ,lower('{finalTableName}') as synapse_target_table_name
                            ,r.params
                            ,r.modify_datetime
                            ,r.processing_date
                FROM        admin_raw r 
                LEFT JOIN   admin_process p 
                ON          r.project_name = p.project_name
                AND         r.load_type = p.load_type
                AND         lower(r.synapse_name) = lower(p.synapse_source_table_name)
                AND         lower(p.synapse_target_table_name) = lower('{finalTableName}')
                AND         r.params = p.params
                AND         r.processing_date = p.processing_date
                WHERE       lower(r.synapse_name) = lower('{finalSourceTableName}')
                AND         p.project_name IS NULL
            """)

        # count() triggers a Spark job, so run it only once
        partition_count = partition_to_process_df.count()
        print(f"{finalTableName} -> Partitions to process: {partition_count}")
        logger.info(f"{finalTableName} -> Partitions to process: {partition_count}")

        union_df = union_df.union(partition_to_process_df)

    collected_rows = union_df.collect()

    # Process partitions in params order; NULL params sort last instead of breaking the comparison
    return sorted(collected_rows, key=lambda row: (row.params is None, row.params or ""))

#####################################################################

def getJDBCdataWithLinkedService(linked_service_name, tableName) -> DataFrame:
    """ 
    Retrieves data from a specified table in a SQL database via JDBC using credentials from a linked service in Azure Synapse.

    The database endpoint and name are read from the linked service properties, and an access token is
    obtained for the linked service. The token is passed as the JDBC ``accessToken`` property and is never
    printed or logged, so it does not end up in notebook output.

    Parameters:
        linked_service_name (str): The name of the linked service in Azure Synapse which contains the connection information.
        tableName (str): The name of the table from which to fetch the data. Ex: "dbo.v_table_name"

    Returns:
        DataFrame: A DataFrame containing the data from the specified table in the database.
    """
    linked_service_params = json.loads(mssparkutils.credentials.getPropertiesAll(linked_service_name))

    # Fully qualified domain name (FQDN) or IP address of the SQL Server, e.g. "sqlserver.database.windows.net"
    endpoint = linked_service_params['Endpoint']

    # Name of the database to connect to, e.g. "my_database"
    database = linked_service_params['Database']

    db_properties = {
        # Token for token-based authentication, obtained from the linked service. Never print it.
        "accessToken": mssparkutils.credentials.getConnectionStringOrCreds(linked_service_name),
        # Microsoft JDBC Driver for SQL Server; must be available in the Spark environment
        "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
    }

    logger.info(f"Load data from Database: {database}, table: {tableName}")

    df = spark.read.jdbc(f"jdbc:sqlserver://{endpoint};databaseName={database}",tableName,properties=db_properties)

    return df

#####################################################################

def getJDBCdataWithLinkedServiceSQLAuth(linked_service_name, tableName) -> DataFrame:
    """ 
    Retrieves data from a specified table in a SQL database via JDBC using SQL authentication from a linked service in Azure Synapse.

    The endpoint, database, user name ('Id') and password ('AuthKey') are read from the linked service
    properties. The credentials are passed as JDBC connection properties instead of being embedded in
    the URL, so passwords containing URL-significant characters such as ';' do not break the connection
    string. Credentials are never printed or logged.

    Parameters:
        linked_service_name (str): The name of the linked service in Azure Synapse which contains the connection information.
        tableName (str): The name of the table from which to fetch the data. Ex: "dbo.v_table_name"

    Returns:
        DataFrame: A DataFrame containing the data from the specified table in the database.
    """
    linked_service_params = json.loads(mssparkutils.credentials.getPropertiesAll(linked_service_name))

    # Fully qualified domain name (FQDN) or IP address of the SQL Server, e.g. "sqlserver.database.windows.net"
    endpoint = linked_service_params['Endpoint']

    # Name of the database to connect to, e.g. "my_database"
    database = linked_service_params['Database']

    db_properties = {
        # SQL authentication user name, taken from the linked service 'Id' property
        "user": linked_service_params['Id'],
        # Password for that user, taken from the linked service 'AuthKey' property. Never print it.
        "password": linked_service_params['AuthKey'],
        # Microsoft JDBC Driver for SQL Server; must be available in the Spark environment
        "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
    }

    logger.info(f"Load data from Database: {database}, table: {tableName}")

    df = spark.read.jdbc(f"jdbc:sqlserver://{endpoint};databaseName={database}",tableName,properties=db_properties)

    return df


#####################################################################

def format_column_names(df, forbidden_characters="+/? .,;{}()\n\t="):
    """
    Normalise DataFrame column names into lower-case, underscore-separated names.

    Each column name is transformed in this order:
      1. every character in ``forbidden_characters`` is replaced with ``_``;
      2. leading and trailing ``_`` are stripped;
      3. ``&`` is replaced with ``_and_`` and ``%`` with ``_percent_``;
      4. runs of ``_`` are collapsed into a single ``_``;
      5. the name is lower-cased.
    Because step 3 runs after the strip, a name that starts or ends with ``&``/``%`` keeps a
    leading/trailing underscore (e.g. ``"Rate%"`` -> ``"rate_percent_"``). This order is left
    unchanged so that names already produced by this function stay stable.

    Parameters:
        df: Spark DataFrame whose columns should be renamed.
        forbidden_characters (str): Characters that are each replaced with ``_``.

    Returns:
        df: DataFrame with renamed columns (columns whose name is unchanged are left alone).
    """
    # One translation table replaces every forbidden character in a single pass
    to_underscore = str.maketrans(dict.fromkeys(forbidden_characters, "_"))

    for old_column_name in df.columns:
        column = str(old_column_name).translate(to_underscore)
        column = column.strip("_")  # trim `_` at the start and end of the name
        column = column.replace("&", "_and_").replace("%", "_percent_")
        while "__" in column:  # collapse any `__` runs
            column = column.replace("__", "_")
        column = column.lower()
        if column != old_column_name:  # rename column only if its name has changed
            df = df.withColumnRenamed(old_column_name, column)
    return df

#####################################################################

def get_table_env_name(schema_name, table_name, env, project_name, ignore_env = None, ignore_project = None):
    """
    Build a ``schema.table`` name qualified with optional project and environment parts.

    The schema becomes ``<project>_<schema>_<env>``. The project prefix is added only when
    ``project_name`` is truthy and ``ignore_project`` is falsy. The environment suffix is added
    only when ``env`` is truthy and ``ignore_env`` is falsy.

    Parameters:
    schema_name (str): Base schema name.
    table_name (str): Table name appended after the dot.
    env (str | None): Environment suffix such as 'dev' or 'prod'.
    project_name (str | None): Project prefix.
    ignore_env (bool, optional): When truthy, the environment suffix is skipped.
    ignore_project (bool, optional): When truthy, the project prefix is skipped.

    Returns:
    str: The qualified name, e.g. ``get_table_env_name("s", "t", "dev", "p") == "p_s_dev.t"``.
    """
    effective_project = None if ignore_project else project_name
    effective_env = None if ignore_env else env

    qualified_schema = schema_name
    if effective_project:
        qualified_schema = "_".join((effective_project, qualified_schema))
    if effective_env:
        qualified_schema = "_".join((qualified_schema, effective_env))

    return f"{qualified_schema}.{table_name}"

#####################################################################

def get_max_of_incremental_column_name(metadata_json):
    """
    Return the maximum value of the incremental column in the job's target, or None.

    The column name is ``metadata["incremental_column_name"]`` (default ``LOAD_DATE``).
    For ``opertation_type == "Linked Service Type"`` targets, the data is read from the mounted
    linked-service path (format ``parquet``, ``delta``, ``csv`` or ``xlsx``) and exposed as the
    temp view ``target_view``. Otherwise the environment-specific table name built with
    ``get_table_env_name`` is queried.

    Parameters:
    metadata_json (str): Job metadata JSON (env, project_name, target_table, incremental_column_name).

    Returns:
    The MAX value of the column, or None when reading the target or running the query fails.
    Failures are logged together with their cause; they are not raised.
    """
    # Metadata store
    metadata = json.loads(metadata_json)

    env = None if metadata["env"]=="None" else metadata["env"]
    project_name = metadata["project_name"]

    target_table = metadata["target_table"]
    target_schema_name = target_table.get("target_schema","")
    target_table_name = target_table.get("target_table","")

    # "params" may be absent or explicitly null
    target_params = target_table.get("params") or {}
    target_ignore_env = target_params.get("ignore_env",False)
    target_ignore_project = target_params.get("ignore_project",False)

    target_linked_service = target_table.get("target_linked_service","")
    target_path = target_table.get("target_path","")
    target_container = target_table.get("target_container","")

    opertation_type = target_table.get("opertation_type",None)
    incremental_column_name = metadata.get("incremental_column_name","LOAD_DATE")

    try:

        if opertation_type == "Linked Service Type":

            mount_scope = "workspace"
            mount_point_name = get_mount_name(target_linked_service, target_container, mount_scope, env)
            mount_point[mount_point_name], mount_synfs_path[mount_point_name] = mount_from_linkedservice(target_linked_service, target_container, mount_scope, env)
            target_file_synfs_path = f"{mount_synfs_path[mount_point_name]}{target_path}/{target_table_name}"
            format_type = target_table.get("format","delta")

            if format_type == 'parquet':
                dataframe = spark.read.parquet(target_file_synfs_path)
            elif format_type == 'delta':
                dataframe = spark.read.format("delta").load(target_file_synfs_path)
            elif format_type == 'csv':
                # Missing/null options mean "read with Spark defaults"
                csv_options = target_table.get("csv_options") or {}
                dataframe = spark.read.options(**csv_options).csv(target_file_synfs_path)
            elif format_type == 'xlsx':
                xlsx_options = target_table.get("xlsx_options") or {}
                dataframe = spark.read.format("com.crealytics.spark.excel").options(**xlsx_options).load(f"{target_file_synfs_path}.xlsx")
            else:
                raise ValueError(f"Unsupported target format: {format_type}")
            synapse_target_table_name = "target_view"
            dataframe.createOrReplaceTempView(synapse_target_table_name)

        else:
            synapse_target_table_name = get_table_env_name(target_schema_name, target_table_name, env, project_name, target_ignore_env, target_ignore_project)

        max_value_from_table = spark.sql(f"""SELECT MAX(`{incremental_column_name}`) FROM {synapse_target_table_name}""").collect()[0][0]
        print(f"Max value of {incremental_column_name} columns is: {max_value_from_table}")
        logger.info(f"Max value of {incremental_column_name} columns is: {max_value_from_table}")
    except Exception as e:
        # Best effort: a missing/unreadable target simply means "no previous max value"
        print(f"No max value found for {incremental_column_name} columns")
        logger.info(f"No max value found for {incremental_column_name} columns")
        logger.warning(f"Reading max value of {incremental_column_name} failed: {e}")
        max_value_from_table = None

    return max_value_from_table
#####################################################################

def get_log_data(metadata_json, processing_date_to_log):

    """
    Generates log data for each source table defined in the provided metadata JSON.

    Parameters:
    metadata_json (str): A JSON string with job metadata (env, project_name, target_table,
        source_tables and an optional process_date).
    processing_date_to_log (str | datetime): Value stored in ``modify_datetime``. Strings must
        be 'YYYY-MM-DD HH:MM:SS' or 'YYYY-MM-DD'; any other type is used as-is.

    Returns:
    list: One Spark Row per source table with project_name, load_type, synapse_source_table_name,
        synapse_target_table_name, params, modify_datetime and processing_date.

    Raises:
    ValueError: If process_date or processing_date_to_log has an unsupported format.

    Function Logic:
    1. Parse the metadata JSON and build the target name (container/linked service/path for
       "Linked Service Type" targets, otherwise via `get_table_env_name`).
    2. Parse the process date and the modification date once; they are the same for every source.
    3. For each source table, build its qualified name according to its operation type and
       append a Row with the logging information.
    """

    # Metadata store
    metadata = json.loads(metadata_json)

    env = None if metadata["env"]=="None" else metadata["env"]
    project_name = metadata["project_name"]

    target_table = metadata["target_table"]
    target_schema_name = target_table.get("target_schema","")
    target_table_name = target_table.get("target_table","")

    if target_table.get("opertation_type",None) == "Linked Service Type":
        target_linked_service = target_table.get("target_linked_service","")
        target_path = target_table.get("target_path","")
        target_container = target_table.get("target_container","")
        synapse_target_table_name = f"{target_container}/{target_linked_service}/{target_path}/"
    else:
        synapse_target_table_name = get_table_env_name(target_schema_name, target_table_name, env, project_name, None, None)

    source_tables = metadata["source_tables"]
    if not source_tables:
        return []

    # Both dates are identical for every source table, so parse them only once
    process_date = _parse_log_process_date(metadata.get("process_date"))
    modify_datetime = _parse_log_modify_datetime(processing_date_to_log)

    rows = []

    for source_table in source_tables:

        source_opertation_type = source_table.get("opertation_type",None) or 'lake database view'
        load_type = source_table.get("load_type",source_opertation_type)

        param = source_table.get("params",None)
        ignore_env = param.get("ignore_env",False) if param else None
        ignore_project = param.get("ignore_project",False) if param else None

        if source_opertation_type == "Linked Service":
            linked_service_name = source_table.get("linked_service_name",None)
            container = source_table.get("container",None)
            path = source_table.get("path",None)
            synapse_source_table_name = f"{container}/{linked_service_name}/{path}/"
        elif source_opertation_type in ("serverless jdbc ls", "serverless jdbc ls sqlauth"):
            # NOTE(review): JDBC sources are logged under the *target* schema/table name, not the
            # source "table" entry. Kept as-is; confirm this is intended.
            synapse_source_table_name = get_table_env_name(target_schema_name, target_table_name, env, project_name, ignore_env, ignore_project)
        else:
            source_schema_name = source_table.get("schema",None)
            source_table_name = source_table.get("table",None)
            synapse_source_table_name = get_table_env_name(source_schema_name, source_table_name, env, project_name, ignore_env, ignore_project)

        rows.append(Row(
                project_name = project_name,
                load_type = load_type,
                synapse_source_table_name = synapse_source_table_name,
                synapse_target_table_name = synapse_target_table_name,
                # NOTE(review): params is populated with load_type; confirm against readers of the metadata table
                params = load_type,
                modify_datetime = modify_datetime,
                processing_date = process_date
                ))

    return rows


def _parse_log_process_date(process_date_str):
    """Parse the metadata process_date; a missing, null or 'None' value means today at 00:00."""
    if process_date_str is None or process_date_str == 'None':
        return datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
    for date_format in ('%Y-%m-%d', '%Y-%m-%d %H:%M:%S'):
        try:
            return datetime.strptime(process_date_str, date_format)
        except ValueError:
            pass
    # Last resort: keep only the date part before the first whitespace, e.g. '2024-01-31 10:00:00.123'
    try:
        return datetime.strptime(process_date_str.split()[0], '%Y-%m-%d')
    except (ValueError, IndexError):
        raise ValueError(f"Not supported date format: {process_date_str}") from None


def _parse_log_modify_datetime(value):
    """Parse a 'YYYY-MM-DD HH:MM:SS' or 'YYYY-MM-DD' string; non-string values are returned unchanged."""
    if not isinstance(value, str):
        return value
    for date_format in ('%Y-%m-%d %H:%M:%S', '%Y-%m-%d'):
        try:
            return datetime.strptime(value, date_format)
        except ValueError:
            pass
    raise ValueError(f"Not supported date format: {value}")



#####################################################################

def repartition(finalTableName, sort_order = None, partitionBy = None):

    """
    Repartitions a Delta table so that its files are close to the target size (~60 MB).

    Parameters:
    finalTableName (str): The name of the table to repartition.
    sort_order (str, optional): Column names to sort by during optimization. Currently unused
        (the OPTIMIZE/ZORDER step is disabled). Defaults to None.
    partitionBy (list, optional): Column names to partition the rewritten table by. Defaults to None (no partitioning).

    Function Logic:
    1. Compute the target number of partitions from the table size (DESCRIBE DETAIL sizeInBytes).
    2. Skip the rewrite when the table is empty or the current partition count is within +/-5% of the target.
    3. Otherwise write a repartitioned copy to `<table>_tmp`, rename the original to `<table>_bcp`,
       rename the copy to the final name and drop the backup.
    4. If the swap fails after the original was renamed, the backup is renamed back to the final name.
    5. Perform VACUUM on the final table.
    """

    partition_columns = partitionBy or []

    finalTableName_tmp = finalTableName+"_tmp"
    finalTableName_bcp = finalTableName+"_bcp"

    details = spark.sql(f"DESCRIBE DETAIL {finalTableName}").collect()[0]
    size_in_mb = details['sizeInBytes'] / (1024 * 1024)

    # Target 60 MB: aiming at the usual 128 MB tends to produce files slightly above 128 MB
    optimal_partition_size_mb = 60
    num_partitions = math.ceil(size_in_mb / optimal_partition_size_mb)

    print(f"Size of table {finalTableName} is {size_in_mb} MB")
    print(f"Number of partitions will be created: {num_partitions}")

    df = spark.table(finalTableName)
    current_num_partitions = df.rdd.getNumPartitions()

    lower_bound = num_partitions * 0.95
    upper_bound = num_partitions * 1.05

    print(f"Current number of partitions: {current_num_partitions}")
    print(f"Optimal number of partitions: {num_partitions}")

    # An empty table (0 partitions targeted) is never rewritten: repartition(0) is invalid
    needs_repartition = num_partitions > 0 and not (lower_bound <= current_num_partitions <= upper_bound)

    if not needs_repartition:
        print(f"No need for table repartitioning: {finalTableName}")
        return

    # Tracks how far the table swap got, so a failure can be rolled back
    swap_stage = None
    try:
        print(f"Repartition table {finalTableName}")
        (df.repartition(num_partitions)
            .write.option("parquet.vorder.enabled", "true")
            .format("delta")
            .mode("overwrite")
            .partitionBy(*partition_columns)
            .saveAsTable(finalTableName_tmp))
        spark.sql(f"ALTER TABLE {finalTableName} RENAME TO {finalTableName_bcp}")
        swap_stage = "original_renamed"
        spark.sql(f"ALTER TABLE {finalTableName_tmp} RENAME TO {finalTableName}")
        swap_stage = "swapped"
        spark.sql(f"DROP TABLE {finalTableName_bcp}")

    except Exception as e:
        print(f"An error occurred: {e}")
        if swap_stage == "original_renamed":
            # The original data is only reachable under the backup name: put it back
            try:
                spark.sql(f"ALTER TABLE {finalTableName_bcp} RENAME TO {finalTableName}")
                print(f"Restored {finalTableName} from {finalTableName_bcp}")
            except Exception as restore_error:
                print(f"Could not restore {finalTableName} from {finalTableName_bcp}: {restore_error}")

    print(f"VACUUM table {finalTableName}")
    spark.sql(f"VACUUM {finalTableName}")

    #if sort_order:
    #    print(f"OPTIMIZE table {finalTableName}")
    #    spark.sql(f"OPTIMIZE {finalTableName} ZORDER BY ({sort_order}) VORDER")

#####################################################################

def _spark_sql_string_literal(value):
    """Render ``value`` as a single-quoted Spark SQL string literal.

    Backslashes and single quotes are backslash-escaped, which is how Spark SQL
    unescapes string literals with its default parser settings.
    """
    text = str(value)
    return "'" + text.replace("\\", "\\\\").replace("'", "\\'") + "'"


def _resolve_lineage_source(source_table, project_name, env):
    """Resolve one ``source_tables`` metadata entry to its lineage name.

    Returns a ``(source_name, linked_service_name)`` tuple, or ``None`` when the
    entry's ``opertation_type`` is not one of the supported values.
    ``linked_service_name`` is ``None`` for "lake database view" sources.
    """
    operation_type = source_table.get("opertation_type", "lake database view")
    params = source_table.get("params") or {}

    if operation_type == "Linked Service":
        linked_service_name = source_table.get("linked_service_name", "None")
        container = source_table.get("container", "None")
        path = source_table.get("path", "None")
        return f"{linked_service_name}/{container}/{path}", linked_service_name

    if operation_type in ("serverless jdbc ls", "serverless jdbc ls sqlauth"):
        return source_table.get("table", None), source_table.get("linked_service_name", "None")

    if operation_type == "lake database view":
        schema_name = project_name + '_' + source_table.get("schema", None)
        # Suffix the schema with the environment unless the source opts out via params.ignore_env
        if env and not params.get("ignore_env", False):
            schema_name = schema_name + "_" + env
        return f"{schema_name}.{source_table.get('table', None)}", None

    if operation_type in ("delta sharing", "delta sharing 1.2.0"):
        schema_name = source_table.get("profile", None)
        if env and not params.get("ignore_env", False):
            schema_name = schema_name + "_" + env
        return f"{schema_name}.{source_table.get('table', None)}", source_table.get("key_vaule_ls", None)

    return None


def log_to_datalineage(metadata_json, admin_schema_name, data_lineage_table_name):
    """
    Log data lineage information into a Delta table via MERGE.

    One lineage row (target -> source) is produced per entry in
    ``metadata['source_tables']``. For "serverless jdbc ls" and
    "serverless jdbc ls sqlauth" sources an additional row links the JDBC table
    to its linked service. Rows are keyed by (target_table_name, source_table_name);
    duplicates within one call are collapsed (last one wins) so the MERGE never
    sees two source rows for the same target row.

    Parameters:
    metadata_json (str): JSON string with keys "env", "project_name",
        "target_table" and "source_tables".
    admin_schema_name (str): Schema of the administrative tables; suffixed with
        "_<env>" when an environment is set.
    data_lineage_table_name (str): Lineage table name; prefixed with "<project_name>_".

    Behaviour:
    - Source entries with an unrecognised "opertation_type" are skipped with a warning.
    - ConcurrentAppendException during the MERGE is retried up to 3 times with
      exponential backoff; any other MERGE error is printed/logged and swallowed.
    """
    # Local import kept from the original; presumably guards against notebook code
    # rebinding the module-level name `date`.
    from datetime import date

    max_retries = 3
    jdbc_operation_types = ("serverless jdbc ls", "serverless jdbc ls sqlauth")

    metadata = json.loads(metadata_json)
    today = date.today().strftime('%Y-%m-%d')

    logger.info(f"Log data lineage")
    print(f"Log data lineage")

    notebook_name = mssparkutils.runtime.context['currentNotebookName']

    env = None if metadata["env"] == "None" else metadata["env"]
    env_for_table = "prod" if metadata["env"] == "None" else metadata["env"]
    project_name = metadata["project_name"]

    # Target table
    target_meta = metadata["target_table"]
    target_schema_name = project_name + '_' + target_meta.get("target_schema", "")
    target_table_name = target_meta.get("target_table", "")

    if target_meta.get("opertation_type", None) == "Linked Service Type":
        target_linked_service = target_meta.get("target_linked_service", "None")
        target_path = target_meta.get("target_path", "None")
        target_container = target_meta.get("target_container", "None")
        finalTableName = f"{target_linked_service}/{target_container}/{target_path}"
    else:
        if env:
            target_schema_name = target_schema_name + "_" + env
        finalTableName = f"{target_schema_name}.{target_table_name}"

    if env:
        admin_schema_name = admin_schema_name + "_" + env
    final_data_lineage_table_name = f"{admin_schema_name}.{project_name}_{data_lineage_table_name}"

    target_params = json.dumps(target_meta)

    # Keyed by (target_table_name, source_table_name) to match the MERGE key.
    rows_by_key = {}
    for source_table in metadata['source_tables']:
        operation_type = source_table.get("opertation_type", "lake database view")
        resolved = _resolve_lineage_source(source_table, project_name, env)
        if resolved is None:
            message = f"Skipping data lineage source with unsupported opertation_type: {operation_type}"
            print(message)
            logger.warning(message)
            continue
        finalSourceTableName, source_linked_service_name = resolved

        row = (notebook_name,
               env_for_table,
               finalTableName,
               target_params,
               finalSourceTableName,
               json.dumps(source_table),
               today)
        rows_by_key[(row[2], row[4])] = row

        # Additional row linking the JDBC table to its linked service
        if operation_type in jdbc_operation_types:
            jdbc_row = (notebook_name,
                        env_for_table,
                        finalSourceTableName,
                        "",
                        source_linked_service_name,
                        "",
                        today)
            rows_by_key[(jdbc_row[2], jdbc_row[4])] = jdbc_row

    data = list(rows_by_key.values())
    if not data:
        print("No data lineage rows to log")
        logger.info("No data lineage rows to log")
        return

    schema = StructType([
        StructField("notebook_name", StringType(), False),
        StructField("env", StringType(), True),
        StructField("target_table_name", StringType(), False),
        StructField("target_params", StringType(), True),
        StructField("source_table_name", StringType(), False),
        StructField("source_params", StringType(), True),
        StructField("_ModifiedDate", StringType(), False)
    ])

    df = spark.createDataFrame(data, schema)
    df.createOrReplaceTempView("new_data_lineage_data")

    print(f"Data lineage table: {final_data_lineage_table_name}")
    logger.info(f"Data lineage table: {final_data_lineage_table_name}")

    # Literal list of every target name in this batch. It keeps the explicit target
    # predicate, which narrows the rows the MERGE touches, while still covering the
    # extra JDBC rows. Those rows have a different target name; restricting to
    # finalTableName alone made them never match, so they were re-inserted on every run.
    target_names_sql = ", ".join(
        _spark_sql_string_literal(name) for name in dict.fromkeys(row[2] for row in data)
    )

    merge_sql = f"""
                MERGE INTO {final_data_lineage_table_name} AS target
                USING new_data_lineage_data AS source
                ON target.target_table_name = source.target_table_name
                AND target.source_table_name = source.source_table_name
                AND target.target_table_name IN ({target_names_sql})
                WHEN MATCHED THEN
                UPDATE SET
                    target.notebook_name = source.notebook_name,
                    target.env = source.env,
                    target.target_params = source.target_params,
                    target.source_params = source.source_params,
                    target._ModifiedDate = source._ModifiedDate
                WHEN NOT MATCHED THEN
                INSERT (notebook_name, env, target_table_name, target_params, source_table_name, source_params, _ModifiedDate)
                VALUES (source.notebook_name, source.env, source.target_table_name, source.target_params, source.source_table_name, source.source_params, source._ModifiedDate)
            """

    for attempt in range(max_retries):
        try:
            spark.sql(merge_sql)
            return
        except ConcurrentAppendException as e:
            print(f"Write conflict detected on attempt {attempt+1}. Error: {e}. Retrying...")
            if attempt + 1 < max_retries:
                time.sleep(2 ** attempt)  # Exponential backoff
        except Exception as e:
            # Lineage logging is best-effort: report and continue.
            print(f"An error occurred: {e}")
            logger.error(f"An error occurred: {e}")
            return

    message = (f"Data lineage MERGE into {final_data_lineage_table_name} "
               f"gave up after {max_retries} write conflicts")
    print(message)
    logger.warning(message)

## Custom rules

In [1]:
def get_packrule_by_columns(static_rules_temp_view_name, packrule_FUNCTION, packrule_CRITERIA, KEY_COLUMNS_LIST: list, VALUE_COLUMN_LIST: list, FN_COLUMN_LIST: list, OVERRIDE_COLUMN):
    """
    Build a SQL CASE expression from the pack rules stored in a rules view.

    Rules are read from ``static_rules_temp_view_name``, filtered by FUNCTION (and by
    CRITERIA when given) and ordered by SEQUENCE ascending. Each rule contributes one
    ``WHEN`` branch. For the i-th key:

    - the KEY column holds the name of the data column to test;
    - the VALUE column holds the value to compare with;
    - the FN column holds the comparison operator.

    Special markers in rule rows:
      - VALUE == "[Any/Other]": that key adds no condition.
      - KEY == "[Else]": the rule's result replaces the default ``ELSE NULL`` and the
        rule produces no WHEN branch.
      - NULL KEY: that key adds no condition.
    The RESULTn_TYPE column controls quoting: "[column]" renders comparison values (and an
    [Else] result) as backtick-quoted column references; otherwise they are double-quoted
    literals. WHEN results are always double-quoted literals.

    Args:
        static_rules_temp_view_name (str): Table/view containing the rules.
        packrule_FUNCTION (str): Value of the FUNCTION column to filter on.
        packrule_CRITERIA (str | None): Value of the CRITERIA column; None disables the filter.
        KEY_COLUMNS_LIST (list): Subset of KEY1..KEY6.
        VALUE_COLUMN_LIST (list): Subset of VALUE1..VALUE6, same length as KEY_COLUMNS_LIST.
        FN_COLUMN_LIST (list): Subset of FN1..FN6, at least as long as KEY_COLUMNS_LIST.
        OVERRIDE_COLUMN (str): One of RESULT1..RESULT6; holds the CASE result.

    Returns:
        str: A SQL CASE expression.

    Raises:
        TypeError: If a column-list argument is not a list.
        Exception: If column lists have mismatched lengths or contain disallowed names.
    """
    # Allowed columns
    allowed_KEY_COLUMNS_list    = ['KEY1','KEY2','KEY3','KEY4','KEY5','KEY6']
    allowed_VALUE_COLUMN_list   = ['VALUE1','VALUE2','VALUE3','VALUE4','VALUE5','VALUE6']
    allowed_FN_COLUMN_list      = ['FN1','FN2','FN3','FN4','FN5','FN6']
    allowed_RESULT_COLUMN_list  = ['RESULT1','RESULT2','RESULT3','RESULT4','RESULT5','RESULT6']
    # Map RESULT to RESULT TYPE column
    column_mapping_RESULT1_TYPE = {
        "RESULT1": "RESULT1_TYPE",
        "RESULT2": "RESULT2_TYPE",
        "RESULT3": "RESULT3_TYPE",
        "RESULT4": "RESULT4_TYPE",
        "RESULT5": "RESULT5_TYPE",
        "RESULT6": "RESULT6_TYPE"
    }

    # Validate parameters
    if not isinstance(KEY_COLUMNS_LIST, list):
        raise TypeError("KEY_COLUMNS_LIST must be a list")
    if not isinstance(VALUE_COLUMN_LIST, list):
        raise TypeError("VALUE_COLUMN_LIST must be a list")
    if not isinstance(FN_COLUMN_LIST, list):
        raise TypeError("FN_COLUMN_LIST must be a list")

    if len(KEY_COLUMNS_LIST) != len(VALUE_COLUMN_LIST):
        raise Exception("Number of KEY columns does not equal to number of VALUE columns")

    # The positional row layout below needs one FN column per KEY column.
    if len(FN_COLUMN_LIST) < len(KEY_COLUMNS_LIST):
        raise Exception("Number of FN columns is lower than number of KEY columns")

    if not all(item in allowed_KEY_COLUMNS_list for item in KEY_COLUMNS_LIST):
        raise Exception("Not allowed KEY column name. Allowed: " + ', '.join(allowed_KEY_COLUMNS_list))

    if not all(item in allowed_VALUE_COLUMN_list for item in VALUE_COLUMN_LIST):
        raise Exception("Not allowed VALUE column name. Allowed: " + ', '.join(allowed_VALUE_COLUMN_list))

    if not all(item in allowed_FN_COLUMN_list for item in FN_COLUMN_LIST):
        raise Exception("Not allowed FN column name. Allowed: " + ', '.join(allowed_FN_COLUMN_list))

    if OVERRIDE_COLUMN not in allowed_RESULT_COLUMN_list:
        raise Exception("Not allowed RESULT column name. Allowed: " + ', '.join(allowed_RESULT_COLUMN_list))

    RESULT_TYPE = column_mapping_RESULT1_TYPE[OVERRIDE_COLUMN]

    columns_to_select = ', '.join(KEY_COLUMNS_LIST + VALUE_COLUMN_LIST + FN_COLUMN_LIST + [OVERRIDE_COLUMN, RESULT_TYPE])

    # If criteria is provided then filter the packrule table on the CRITERIA value
    if packrule_CRITERIA is None:
        filter_CRITERIA_query = ""
    else:
        filter_CRITERIA_query = f"""AND CRITERIA = "{packrule_CRITERIA}" """

    packrules_cnt = spark.sql(f"""SELECT SEQUENCE, {columns_to_select} FROM {static_rules_temp_view_name} WHERE FUNCTION = "{packrule_FUNCTION}" {filter_CRITERIA_query} ORDER BY SEQUENCE ASC """)

    pd_packrules_df = packrules_cnt.toPandas()

    n_keys = len(KEY_COLUMNS_LIST)

    case_query = "CASE "
    # Default end of the CASE expression; replaced by an [Else] rule if one exists
    case_query_end = "ELSE NULL END "

    # Positional row layout: SEQUENCE, KEY*n, VALUE*n, FN*len(FN_COLUMN_LIST), OVERRIDE, RESULT_TYPE.
    # .iloc is used because plain integer indexing on a label-indexed Series is deprecated.
    for _, r in pd_packrules_df.iterrows():
        case = ""
        if_else_condition = False
        result_type = r[RESULT_TYPE]

        for item_no in range(n_keys):
            key_value = r.iloc[1 + item_no]
            if key_value is None:
                # NULL key: no condition for this key, continue with the next one
                continue
            compare_value = r.iloc[1 + item_no + n_keys]
            fn = r.iloc[1 + item_no + 2 * n_keys]

            if compare_value == "[Any/Other]":
                continue
            if key_value == "[Else]":
                # Declared [Else] condition becomes the ELSE branch
                if result_type == "[column]":
                    case_query_end = f"""ELSE `{r[OVERRIDE_COLUMN]}`  END """
                else:
                    case_query_end = f"""ELSE "{r[OVERRIDE_COLUMN]}"  END """
                if_else_condition = True
            elif result_type == "[column]":
                # Compare against another column: use backticks
                case += f""" AND `{key_value}` {fn} `{compare_value}` """
            else:
                # "[value]" or undeclared: compare against a string literal
                case += f""" AND `{key_value}` {fn} "{compare_value}" """

        if not if_else_condition:
            # case[5:-1] strips the leading " AND " and the trailing space.
            # NOTE(review): a rule whose keys all add no condition yields "WHEN  THEN",
            # which is invalid SQL — confirm how such rules should be treated.
            case_query += f"""WHEN {case[5:-1]} THEN "{r[OVERRIDE_COLUMN]}" 
            """

    case_query += case_query_end
    return case_query


#Example
#get_packrule_by_columns("static_ods_rules","MDG_MARS_SEGMENT_AND_DIVISION", "MARS_SEGMENT_DIVISION", ['KEY1','KEY2'], ['VALUE1','VALUE2'], ['FN1', 'FN2'], 'RESULT1')

In [None]:
def get_ods_rule(rules_temp_view_name, packrule_FUNCTION, packrule_CRITERIA, KEY_COLUMNS_LIST: list, VALUE_COLUMN_LIST: list, FN_COLUMN_LIST: list, OVERRIDE_COLUMN):
    """
    Generate a SQL CASE expression from the rules stored in a rules table/view.

    Each selected rule row becomes one ``WHEN`` branch whose conditions are ANDed.
    For the i-th key:

    - the KEY column holds the name of the data column to test;
    - the VALUE column holds the value to compare with;
    - the FN column holds the operator (e.g. "=", "IN").

    A NULL VALUE produces an ``IS NULL`` test. For "IN" the value is inserted unquoted
    (presumably already a parenthesised list — confirm against the rules data).
    The RESULTTYPEn column decides how the OVERRIDE_COLUMN value is rendered:
    "[column]" -> backtick-quoted column reference, "[value]" -> double-quoted
    literal, anything else -> inserted verbatim. Unmatched rows yield NULL.

    Args:
        rules_temp_view_name (str): The name of the table or view containing the rules.
        packrule_FUNCTION (str): The FUNCTION value to filter the rules by (e.g. "COUNTRY_OF_USE_OVERRIDE").
        packrule_CRITERIA (str): The CRITERIA value to filter by (e.g. "BY_DRIVER"); falsy disables the filter.
        KEY_COLUMNS_LIST (list): Subset of KEY1..KEY6.
        VALUE_COLUMN_LIST (list): Subset of VALUE1..VALUE6, same length as KEY_COLUMNS_LIST.
        FN_COLUMN_LIST (list): Subset of FN1..FN6, at least as long as KEY_COLUMNS_LIST.
        OVERRIDE_COLUMN (str): One of RESULT1..RESULT6; holds the result value / overridden column.

    Returns:
        str: A SQL CASE expression.

    Raises:
        TypeError: If a column-list argument is not a list.
        Exception: If column lists have mismatched lengths or contain disallowed names.
    """
    # Allowed columns
    allowed_KEY_COLUMNS_list    = ['KEY1','KEY2','KEY3','KEY4','KEY5','KEY6']
    allowed_VALUE_COLUMN_list   = ['VALUE1','VALUE2','VALUE3','VALUE4','VALUE5','VALUE6']
    allowed_FN_COLUMN_list      = ['FN1','FN2','FN3','FN4','FN5','FN6']
    allowed_RESULT_COLUMN_list  = ['RESULT1','RESULT2','RESULT3','RESULT4','RESULT5','RESULT6']
    # Map RESULT to RESULT TYPE column
    column_mapping_RESULT_TYPE1 = {
        "RESULT1": "RESULTTYPE1",
        "RESULT2": "RESULTTYPE2",
        "RESULT3": "RESULTTYPE3",
        "RESULT4": "RESULTTYPE4",
        "RESULT5": "RESULTTYPE5",
        "RESULT6": "RESULTTYPE6"
    }

    # Validate input parameters
    if not isinstance(KEY_COLUMNS_LIST, list):
        raise TypeError("KEY_COLUMNS_LIST must be a list")
    if not isinstance(VALUE_COLUMN_LIST, list):
        raise TypeError("VALUE_COLUMN_LIST must be a list")
    if not isinstance(FN_COLUMN_LIST, list):
        raise TypeError("FN_COLUMN_LIST must be a list")

    if len(KEY_COLUMNS_LIST) != len(VALUE_COLUMN_LIST):
        raise Exception("Number of KEY columns does not equal to number of VALUE columns")

    # Every KEY column needs an operator; fail early instead of with an IndexError later.
    if len(FN_COLUMN_LIST) < len(KEY_COLUMNS_LIST):
        raise Exception("Number of FN columns is lower than number of KEY columns")

    if not all(item in allowed_KEY_COLUMNS_list for item in KEY_COLUMNS_LIST):
        raise Exception("Not allowed KEY column name. Allowed: " + ', '.join(allowed_KEY_COLUMNS_list))

    if not all(item in allowed_VALUE_COLUMN_list for item in VALUE_COLUMN_LIST):
        raise Exception("Not allowed VALUE column name. Allowed: " + ', '.join(allowed_VALUE_COLUMN_list))

    if not all(item in allowed_FN_COLUMN_list for item in FN_COLUMN_LIST):
        raise Exception("Not allowed FN column name. Allowed: " + ', '.join(allowed_FN_COLUMN_list))

    if OVERRIDE_COLUMN not in allowed_RESULT_COLUMN_list:
        raise Exception("Not allowed RESULT column name. Allowed: " + ', '.join(allowed_RESULT_COLUMN_list))

    RESULT_TYPE = column_mapping_RESULT_TYPE1[OVERRIDE_COLUMN]

    # Construct the column list for the rules query
    columns_to_select = ', '.join(KEY_COLUMNS_LIST + VALUE_COLUMN_LIST + FN_COLUMN_LIST + [OVERRIDE_COLUMN, RESULT_TYPE])

    # If criteria is provided then filter the rules on the CRITERIA value
    filter_CRITERIA_query = f"""AND CRITERIA = "{packrule_CRITERIA}" """ if packrule_CRITERIA else ""

    # Fetch the rules.
    # NOTE(review): without ORDER BY the row order (and thus WHEN-branch priority) is
    # not guaranteed by Spark; confirm whether the rules table has an ordering column.
    rules_df = spark.sql(f"""
        SELECT {columns_to_select}
        FROM {rules_temp_view_name}
        WHERE FUNCTION = "{packrule_FUNCTION}" {filter_CRITERIA_query}
        -- ORDER BY SEQUENCE ASC
    """)

    pd_rules_df = rules_df.toPandas()

    case_query = "\nCASE"
    case_query_end = "\n    ELSE NULL\nEND\n"

    for _, row in pd_rules_df.iterrows():
        conditions = []
        for key, value_column, fn_column in zip(KEY_COLUMNS_LIST, VALUE_COLUMN_LIST, FN_COLUMN_LIST):
            value = row[value_column]
            fn = row[fn_column]
            if value is None:
                conditions.append(f"""`{row[key]}` IS NULL """)
            elif fn.strip().upper() == "IN":
                # For IN operator, don't put quotes around the value
                conditions.append(f"""`{row[key]}` {fn} {value} """)
            else:
                conditions.append(f"""`{row[key]}` {fn} "{value}" """)

        if not conditions:
            continue

        condition_sql = ' AND '.join(conditions)
        if row[RESULT_TYPE] == "[column]":
            case_query += f"""\n    WHEN {condition_sql} THEN `{row[OVERRIDE_COLUMN]}` """
        elif row[RESULT_TYPE] == "[value]":
            case_query += f"""\n    WHEN {condition_sql} THEN "{row[OVERRIDE_COLUMN]}" """
        else:
            case_query += f"""\n    WHEN {condition_sql} THEN {row[OVERRIDE_COLUMN]} """

    case_query += case_query_end

    return case_query


# Examples:

# case_statement_segment_override = get_ods_rule(
#                                                         rules_temp_view_name="ods_rules",
#                                                         packrule_FUNCTION="MGIS_SEGMENT_OVERRIDE",
#                                                         packrule_CRITERIA="OVERRIDE",
#                                                         KEY_COLUMNS_LIST=["KEY1", "KEY2"],
#                                                         VALUE_COLUMN_LIST=["VALUE1", "VALUE2"],
#                                                         FN_COLUMN_LIST=["FN1", "FN2"],
#                                                         OVERRIDE_COLUMN="RESULT1"
#                                                 )

# case_statement_division_override = get_ods_rule(
#                                                         rules_temp_view_name="ods_rules",
#                                                         packrule_FUNCTION="RUSSIA_DIVISION_OVERRIDE",
#                                                         packrule_CRITERIA="BY_DIVISION",
#                                                         KEY_COLUMNS_LIST=["KEY1", "KEY2"],
#                                                         VALUE_COLUMN_LIST=["VALUE1", "VALUE2"],
#                                                         FN_COLUMN_LIST=["FN1", "FN2"],
#                                                         OVERRIDE_COLUMN="RESULT1"
#                                                 )

# case_statement_new_country_override = get_ods_rule(
#                                                         rules_temp_view_name="ods_rules",
#                                                         packrule_FUNCTION="COUNTRY_OF_USE_OVERRIDE",
#                                                         packrule_CRITERIA="BY_DRIVER",
#                                                         KEY_COLUMNS_LIST=["KEY1", "KEY2"],
#                                                         VALUE_COLUMN_LIST=["VALUE1", "VALUE2"],
#                                                         FN_COLUMN_LIST=["FN1", "FN2"],
#                                                         OVERRIDE_COLUMN="RESULT1"
#                                                     )

# case_statement_new_country_override_1 = get_ods_rule(
#                                                         rules_temp_view_name="ods_rules",
#                                                         packrule_FUNCTION="COUNTRIES_TO_SUBSEGMENTS_MGIS_ADJ",
#                                                         packrule_CRITERIA="BY_COUNTRY_OF_USE",
#                                                         KEY_COLUMNS_LIST=["KEY1"],
#                                                         VALUE_COLUMN_LIST=["VALUE1"],
#                                                         FN_COLUMN_LIST=["FN1"],
#                                                         OVERRIDE_COLUMN="RESULT1"
#                                                     )

## send_notification

In [None]:
def send_notification(message, subject, header, status, color, recipients, env, 
                     kv_secret_name_with_url="ODS-notification-url", 
                     lv_linked_service="SOLUTION_KEY_VAULT_LS",
                     timeout=60):
    """
    Sends a notification using a REST API endpoint with credentials stored in Azure Key Vault.
    
    Parameters:
    -----------
    message : str
        The main body of the notification message
    subject : str
        The subject line of the notification
    header : str
        The header text for the notification
    status : str
        The status indicator for the notification (e.g., 'success', 'error', 'warning', 'info')
    color : str
        The color theme for the notification (e.g., 'blue', 'red', 'green', 'yellow')
    recipients : str
        Email address(es) of the notification recipient(s)
    env : str
        Environment identifier (e.g., 'dev', 'test', 'prod')
    kv_secret_name_with_url : str, optional
        The name of the secret in Key Vault containing the notification service URL
        (default: "ODS-notification-url")
    lv_linked_service : str, optional
        The name of the linked service for Key Vault access
        (default: "SOLUTION_KEY_VAULT_LS")
    timeout : float or tuple, optional
        Timeout in seconds passed to ``requests.post`` (default: 60). Without a
        timeout an unresponsive endpoint would block the notebook indefinitely.
    
    Returns:
    --------
    bool
        True if the notification was sent successfully (status code 200 or 202)
        False if the request completed with any other status code
    
    Raises:
    -------
    Exception
        If there's an error accessing Key Vault or sending the notification
        (including a request timeout); the error is printed and re-raised.
    
    Example:
    --------
    >>> result = send_notification(
    ...     message="Pipeline completed successfully",
    ...     subject="Pipeline Status Update",
    ...     header="Pipeline Notification",
    ...     status="success",
    ...     color="green",
    ...     recipients="team@company.com",
    ...     env="prod"
    ... )
    >>> print("Notification sent:", result)
    
    Notes:
    ------
    - The function uses mssparkutils.credentials to access Key Vault
    - Success is indicated by HTTP status codes 200 or 202
    - Full response details are printed for debugging purposes
    """
    
    try:
        # Imported explicitly: the utility notebook's import cell does not import requests.
        import requests

        # Get URL from Key Vault
        url = mssparkutils.credentials.getSecretWithLS(lv_linked_service, kv_secret_name_with_url)
        
        # Prepare data
        data = {
            "message": message,
            "subject": subject,
            "header": header,
            "status": status,
            "color": color,
            "recipients": recipients,
            "env": env
        }
        
        # Send request (bounded by timeout so a stuck endpoint cannot hang the job)
        response = requests.post(url, json=data, timeout=timeout)
        
        # Print response details for debugging
        print("Status Code:", response.status_code)
        print("Response Content:", response.text)
        print("Response Headers:", dict(response.headers))
        
        if response.status_code == 200:
            print("Notification sent successfully")
            return True
                
        elif response.status_code == 202:
            print("Request accepted")
            return True
            
        else:
            print(f"Request failed with status code: {response.status_code}")
            print(f"Response content: {response.text}")
            return False
            
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        raise

## repartition_with_salt

In [None]:
def repartition_with_salt(df: DataFrame, target_rows_per_partition: int = 100000) -> DataFrame:
    """
    Rebalances a DataFrame by adding a random salt column and repartitioning on it.

    The number of partitions is ``max(1, row_count // target_rows_per_partition)``.
    The temporary salt column is removed before returning, and its name is chosen
    so it never collides with (and therefore never overwrites or drops) an existing
    column.

    :param df: Input DataFrame
    :param target_rows_per_partition: Approximate target number of rows per partition (default: 100,000)
    :return: Repartitioned DataFrame with the same columns as ``df``
    :raises ValueError: if target_rows_per_partition is not positive
    """
    if target_rows_per_partition <= 0:
        raise ValueError("target_rows_per_partition must be a positive integer")

    total_rows = df.count()

    num_partitions = max(1, total_rows // target_rows_per_partition)

    print(f"Repartition: {num_partitions}")

    # Spark resolves column names case-insensitively by default, so compare lower-cased.
    # A clash would make withColumn overwrite a real column and drop() remove it.
    existing_columns = {column.lower() for column in df.columns}
    salt_column = "salt"
    while salt_column.lower() in existing_columns:
        salt_column = "_" + salt_column

    return (
        df.withColumn(salt_column, (F.rand() * num_partitions).cast("int"))
        .repartition(num_partitions, salt_column)
        .drop(salt_column)
    )

## create_table_from_dataframe

In [None]:

def create_table_from_dataframe(env, schema_name, table_name, dataframe, type = 'parquet', partitionBy = None):
    """
    (Re)create the table ``<schema_name>[_<env>].<table_name>`` from a DataFrame.

    Any existing table with that name is dropped first, then ``dataframe`` is written
    with ``saveAsTable`` in overwrite mode.

    Args:
        env: Environment suffix; when truthy the schema becomes ``<schema_name>_<env>``.
        schema_name: Base schema (database) name.
        table_name: Table name.
        dataframe: Spark DataFrame to persist.
        type: 'parquet' writes without an explicit format (session default, assumed
            parquet); 'delta' writes with ``format("delta")``.
        partitionBy: Optional list of partition column names (default: no partitioning).

    Raises:
        ValueError: If ``type`` is not 'parquet' or 'delta'. Raised before the existing
            table is dropped, so an invalid call no longer destroys data.
    """
    if type not in ('parquet', 'delta'):
        raise ValueError(f"Unsupported table type: {type!r}. Expected 'parquet' or 'delta'.")

    partition_columns = list(partitionBy) if partitionBy else []

    # Formulating the target table name
    if env:
        schema_name = schema_name + "_" + env
    finalTableName = f"{schema_name}.{table_name}"

    # Drop every time for now until everything will be checked (to avoid mergeSchema)
    spark.sql(f"DROP TABLE IF EXISTS {finalTableName}")

    print(f"Create {type} table: {finalTableName}")
    writer = dataframe.write
    if type == 'delta':
        writer = writer.format("delta")
    writer.mode("overwrite").partitionBy(*partition_columns).saveAsTable(finalTableName)

    # Sanity check that the table was created: the query raises (halting the cell)
    # if the table does not exist. Not an assert, so it also runs under `python -O`.
    spark.sql(f"SELECT COUNT(*) as total_rows from {finalTableName}").collect()


## create_table_from_schema

In [None]:
def create_table_from_schema(env, schema_name, table_name, schema: StructType, partitionBy):
    """
    Create an empty Delta table with the given schema if it does not exist yet.

    Args:
        env: Environment suffix; when truthy the table lives in ``<schema_name>_<env>``.
        schema_name: Base schema (database) name, without the env suffix.
        table_name: Table name.
        schema: StructType describing the table columns.
        partitionBy: Partition column names forwarded to create_table_from_dataframe.
    """
    qualified_schema_name = schema_name + "_" + env if env else schema_name
    table_full_name = f"{qualified_schema_name}.{table_name}"

    # Nothing to do when the table already exists
    if delta_table_exists(table_full_name):
        return

    # Empty DataFrame carrying only the schema
    empty_agg_df = spark.createDataFrame(spark.sparkContext.emptyRDD(), schema=schema)

    # create_table_from_dataframe applies the env suffix itself, so pass the base schema name
    create_table_from_dataframe(env, schema_name, table_name, empty_agg_df, "delta", partitionBy)

# additional methods

## list_files_recursive

In [7]:
# Get a list of all files in the data folder and its subfolders, including additional file information
# Supports filtering by file extension and by a substring of the file path (file_name_like)
def list_files_recursive(folder_path, extension='parquet', file_name_like = None):
    """
    Recursively list files under ``folder_path`` with basic file information.

    Args:
        folder_path: Folder to scan (any path accepted by mssparkutils.fs.ls).
        extension: Keep only files whose path ends with ``.<extension>``;
            None disables the filter (default: 'parquet').
        file_name_like: Keep only files whose path contains this substring;
            None disables the filter.

    Returns:
        list[dict]: One dict per file with keys 'path', 'name' and 'modifyTime'
        (local time formatted as "%Y-%m-%d %H:%M:%S").

    Listing is best-effort: if a folder cannot be listed the error is printed and the
    files collected so far for that folder are returned.
    """
    suffix = None if extension is None else f".{extension}"
    all_files = []
    try:
        for file_info in mssparkutils.fs.ls(folder_path):
            if file_info.isDir:
                # Recursively list files in subfolder
                all_files.extend(list_files_recursive(file_info.path, extension, file_name_like))
                continue
            if suffix is not None and not file_info.path.endswith(suffix):
                continue
            if file_name_like is not None and file_name_like not in file_info.path:
                continue
            all_files.append({
                'path': file_info.path,
                'name': file_info.name,
                # modifyTime is divided by 1000 before conversion, i.e. treated as milliseconds since the epoch
                'modifyTime': time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(file_info.modifyTime / 1000))
            })
    except Exception as e:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate
        print(f" Error at {folder_path}: {e}")
    return all_files

## export_to_synfs_ADLS

In [8]:
def _export_extension(section):
    """Return the file extension of an export section.

    Accepts both ``extension`` and the ``file_extension`` key used in the
    usage example; defaults to ``csv``.
    """
    return section.get("extension", section.get("file_extension", "csv"))


def _is_true_flag(value):
    """Interpret a metadata flag given as "True"/"true" string or JSON boolean."""
    return str(value).strip().lower() == "true"


def _copy_export_files(files, target_folder, rename_file_as, current_timestamp, extension):
    """Copy each file into ``target_folder`` as
    ``<rename_file_as>[_<value>]_<current_timestamp>.<extension>``.

    ``<value>`` is taken from the first ``=<value>/`` segment of the source
    path (e.g. a ``col=value/`` partition folder), when there is one.
    """
    for file in files:
        match = re.search(r"=([^/]+)/", file['path'])
        if match:
            final_target_path = f"{target_folder}{rename_file_as}_{match.group(1)}_{current_timestamp}.{extension}"
        else:
            final_target_path = f"{target_folder}{rename_file_as}_{current_timestamp}.{extension}"

        print(f"Export: {file['path']} to {final_target_path}")

        mssparkutils.fs.cp(file['path'], final_target_path, True)


def export_to_synfs_ADLS(metadata_json, current_timestamp, current_yyyymmdd, delete_export_files = False):
    """
    Export files from one ADLS location to another and rename them.

    Each location is addressed by a linked service and container and is accessed
    through a workspace mount. Files can optionally also be backed up.

    Args:
        metadata_json: JSON string with sections ``export_to``, ``export_from``,
            optional ``backup_export_to``, plus ``env`` and ``project_name``.
            ``export_to``/``export_from`` require ``linked_service``,
            ``container`` and ``path``. The extension is read from
            ``extension`` (or ``file_extension``) and defaults to ``csv``.
        current_timestamp: Value embedded in every target file name.
        current_yyyymmdd: Not used by this function; kept for API compatibility.
        delete_export_files: When True, files in the export target folder are
            deleted (best effort) before copying.

    Raises:
        ValueError: If a required location parameter is missing (including
            the backup location when backup is enabled).
    """
    metadata = json.loads(metadata_json)

    # Environment: the literal string "None" in the metadata means "no environment".
    env = None if metadata["env"] == "None" else metadata["env"]
    project_name = metadata["project_name"]

    logger.info(f"ENV: {env}")
    print(f"ENV: {env}")
    logger.info(f"PROJECT_NAME: {project_name}")
    print(f"PROJECT_NAME: {project_name}")

    mount_scope = "workspace"

    export_to = metadata["export_to"]
    export_to_linked_service = export_to.get("linked_service", None)
    export_to_container = export_to.get("container", None)
    export_to_path = export_to.get("path", None)
    export_to_extension = _export_extension(export_to)
    rename_file_as = export_to.get("rename_file_as", "")

    if not export_to_linked_service or not export_to_container or not export_to_path:
        raise ValueError("export to: Required parameters path, container, linked_service not declared")

    export_from = metadata["export_from"]
    export_from_linked_service = export_from.get("linked_service", None)
    export_from_container = export_from.get("container", None)
    export_from_path = export_from.get("path", None)
    export_from_extension = _export_extension(export_from)

    # The backup section is optional; a missing section means "backup disabled".
    backup = metadata.get("backup_export_to") or {}
    backup_export_enabled = backup.get("enabled", "False")
    backup_export_path = backup.get("path", None)
    backup_export_container = backup.get("container", None)
    backup_export_linked_service = backup.get("linked_service", None)
    backup_export_folder_timestamp = backup.get("folder_timestamp", "00000000_000000")
    backup_export_folder = backup.get("folder", "None")
    backup_enabled = _is_true_flag(backup_export_enabled)

    print(f"\n    Backup is:  {backup_export_enabled}")

    if not export_from_linked_service or not export_from_container or not export_from_path:
        raise ValueError("export from: Required parameters path, container, linked_service not declared")

    # Validate the backup location before any mount or copy side effects.
    if backup_enabled and (not backup_export_linked_service or not backup_export_container or not backup_export_path):
        raise ValueError("backup export to: Required parameters path, container, linked_service not declared")

    def _mount(linked_service, container):
        # Registers the mount in the module-level mount_point / mount_synfs_path
        # dicts (declared in "Utils/data_processing_utilis_v2") and returns its synfs path.
        name = get_mount_name(linked_service, container, mount_scope, env)
        mount_point[name], mount_synfs_path[name] = mount_from_linkedservice(linked_service, container, mount_scope, env)
        return mount_synfs_path[name]

    export_to_synfs_path = f"{_mount(export_to_linked_service, export_to_container)}{export_to_path}/"
    export_from_synfs_path = f"{_mount(export_from_linked_service, export_from_container)}{export_from_path}/"

    print(f"Export to: {export_to_synfs_path}")
    print(f"Export from: {export_from_synfs_path}")

    file_list_to_export = list_files_recursive(export_from_synfs_path, export_from_extension)

    if backup_enabled:
        print(f"\n    Backup files:  {backup_export_enabled}")
        # backup_export_path is appended as-is, so it is expected to end with "/".
        backup_export_synfs_path = f"{_mount(backup_export_linked_service, backup_export_container)}{backup_export_path}"

        print(f"backup_export_synfs_path: {backup_export_synfs_path}")

        backup_export_synfs_file_path = f"{backup_export_synfs_path}{backup_export_folder_timestamp}/{backup_export_folder}/"
        _copy_export_files(file_list_to_export, backup_export_synfs_file_path, rename_file_as, current_timestamp, export_to_extension)

    if delete_export_files:
        # Best effort: a missing target folder must not stop the export. The old
        # code referenced undefined names here, so deletion silently never happened.
        try:
            delete_files_in_folder_with_linked_service(export_to_linked_service, export_to_container, export_to_path, env, False)
        except Exception as e:
            print(f"\n    Folder does not exists: {e}")

    _copy_export_files(file_list_to_export, export_to_synfs_path, rename_file_as, current_timestamp, export_to_extension)


########## USAGE ###########

### ------------------------------------------------------------------------------------
#export_metadata_json = '''
#{
#    "export_to":   {"path" :"GLOBAL_XSEG_CORPORATE_ODS/EXPORTS/LORAX/LORAX_VENDORS",  "rename_file_as":"lorax_vendors",  "container": "output" , "linked_service" : "MARS_ANALYTICS_EXPORT_ADLS_LS"  , "extension": "csv"},
#    "backup_export_to":   { "enabled":"'''+str(publish_success_file)+'''",  "path" :"'''+str(backup_export_path)+'''", "folder":"'''+str(output_table_name)+'''", "folder_timestamp": "'''+str(folder_timestamp)+'''", "container": "output" , "linked_service" : "MARS_ANALYTICS_EXPORT_ADLS_LS"},
#    "export_from": {"path":"LORAX/DATA_TO_EXPORT/LORAX_VENDORS",                        "container": "files" ,  "linked_service" : "SOLUTION_ADLS_LS"               , "extension": "csv"},
#    "env": "'''+str(ENV)+'''",
#    "project_name": "'''+str(PROJECT_NAME)+'''"
#}
#'''
#
#export_to_synfs_ADLS(export_metadata_json, current_timestamp, current_yyyymmdd, True)

############################

## create_incremental_table


```
┌─────────────────────────────────────────────────────────────────┐
│                    START DELTA LOAD                             │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌──────────────────────────────────────────────────────────────────┐
│  SCHEMA EVOLUTION CHECK                                          │
│  ─────────────────────                                           │
│  source_columns = get_table_columns(source)     → [A, B, C, NEW] │
│  target_columns = get_table_columns(target)     → [A, B, C]      │
│  new_columns = source - target                  → [NEW]          │
└──────────────────────────────────────────────────────────────────┘
                              │
              ┌───────────────┴───────────────┐
              │                               │
              ▼                               ▼
┌─────────────────────────┐     ┌─────────────────────────────────┐
│ ignore=TRUE             │     │ ignore=FALSE                    │
│ ────────────────────    │     │ ────────────────────            │
│ 1. ALTER TABLE ADD NEW  │     │ 1. ALTER TABLE ADD NEW          │
│ 2. MERGE UPDATE NEW     │     │ 2. (no update - stays NULL)     │
│    from source          │     │                                 │
│                         │     │                                 │
│ columns_for_comparison  │     │ columns_for_comparison          │
│   = [A, B, C] (OLD)     │     │   = [A, B, C, NEW] (ALL)        │
└─────────────────────────┘     └─────────────────────────────────┘
              │                               │
              ▼                               ▼
┌─────────────────────────┐     ┌─────────────────────────────────┐
│ HASH COMPARISON         │     │ HASH COMPARISON                 │
│ ─────────────────────   │     │ ─────────────────────           │
│ source: hash(A,B,C)     │     │ source: hash(A,B,C,NEW)         │
│ target: hash(A,B,C)     │     │ target: hash(A,B,C,NULL)        │
│                         │     │                                 │
│ Result: EQUAL           │     │ Result: NOT EQUAL               │
│ → No change detected    │     │ → Non-NULL NEW rows = Update    │
└─────────────────────────┘     └─────────────────────────────────┘
              │                               │
              ▼                               ▼
┌─────────────────────────┐     ┌─────────────────────────────────┐
│ OUTPUT                  │     │ OUTPUT                          │
│ ─────────────────────   │     │ ─────────────────────           │
│ total_count = 0         │     │ total_count = N (NEW not NULL)  │
│ "No updates to process" │     │ operation_type = 'U'            │
│ last_update_dt UNCHANGED│     │ last_update_dt = NOW            │
└─────────────────────────┘     └─────────────────────────────────┘


In [None]:
from pyspark.sql.functions import col, lit, current_timestamp, to_date
from delta.tables import DeltaTable
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Tuple, Set
import json

# Bookkeeping columns written by the incremental load itself. They never come
# from the source table; get_table_columns(..., exclude_technical=True) drops them.
TECHNICAL_COLUMNS: Set[str] = {"last_update_dt", "operation_type"}


def get_table_columns(table_name: str, exclude_technical: bool = False) -> List[str]:
    """
    Return the top-level column names of a table, in DESCRIBE order.

    Runs ``DESCRIBE <table_name>`` and keeps each ``col_name`` until the first
    row whose name starts with "#" (the partitioning section). Blank names are
    skipped.

    Args:
        table_name: Fully qualified table name.
        exclude_technical: When True, names in TECHNICAL_COLUMNS
            (operation_type, last_update_dt) are left out.

    Returns:
        List of column names.
    """
    names: List[str] = []
    for described in spark.sql(f"DESCRIBE {table_name}").collect():
        name = described.col_name

        # Blank/None rows separate sections; just skip them.
        if not name or not name.strip():
            continue

        # Everything from the first "#" header onwards is metadata, not columns.
        if name.startswith("#"):
            break

        if exclude_technical and name in TECHNICAL_COLUMNS:
            continue

        names.append(name)

    return names


def get_column_types(table_name: str) -> Dict[str, str]:
    """
    Map each top-level column of a table to its DESCRIBE data type string.

    Rows are read from ``DESCRIBE <table_name>`` until the first name starting
    with "#" (the partitioning section). Rows with a blank name or an empty
    data type are ignored.

    Args:
        table_name: Fully qualified table name.

    Returns:
        Dict ``{column_name: data_type}`` in DESCRIBE order.
    """
    rows = spark.sql(f"DESCRIBE {table_name}").collect()
    types_by_column: Dict[str, str] = {}

    for described in rows:
        name, dtype = described.col_name, described.data_type

        if name and name.startswith("#"):
            break

        if not name or name.strip() == "" or not dtype:
            continue

        types_by_column[name] = dtype

    return types_by_column


def detect_new_columns(
    source_table: str, 
    target_table: str
) -> Tuple[List[str], Dict[str, str]]:
    """
    Detect columns present in the source table but missing from the target.

    Only business columns are compared: the target's technical columns
    (TECHNICAL_COLUMNS) are ignored because this process adds them itself.

    Args:
        source_table: Source table name.
        target_table: Target table name.

    Returns:
        Tuple of (new column names in source column order,
        dict ``{column: data_type}`` for those columns).
    """
    source_columns = get_table_columns(source_table)
    # Exclude technical columns when comparing - they are added by this process.
    target_business_columns = set(get_table_columns(target_table, exclude_technical=True))

    # Keep source order rather than set-difference order. Set iteration order of
    # strings changes between runs (str hashing is randomized), and this order
    # decides the order in which ALTER TABLE appends the new columns.
    new_columns = [c for c in source_columns if c not in target_business_columns]

    if not new_columns:
        return new_columns, {}

    source_types = get_column_types(source_table)
    new_column_types = {c: source_types[c] for c in new_columns if c in source_types}

    return new_columns, new_column_types


def add_columns_to_delta_table(
    table_name: str, 
    columns_with_types: Dict[str, str]
) -> List[str]:
    """
    Add columns to an existing Delta table, one ``ALTER TABLE ... ADD COLUMNS`` each.

    A column that already exists (the error message contains "already exists"
    or "found duplicate") is reported and skipped. Any other error is re-raised.

    Args:
        table_name: Fully qualified table name.
        columns_with_types: Dict ``{column_name: data_type}``, e.g. as returned
            by get_column_types.

    Returns:
        Names of the columns that were actually added.
    """
    added_columns: List[str] = []

    if not columns_with_types:
        return added_columns

    for col_name, col_type in columns_with_types.items():
        # A backtick inside a quoted Spark SQL identifier must be doubled.
        quoted_name = "`" + col_name.replace("`", "``") + "`"
        try:
            spark.sql(f"ALTER TABLE {table_name} ADD COLUMNS ({quoted_name} {col_type})")
        except Exception as e:
            error_msg = str(e).lower()
            if "already exists" in error_msg or "found duplicate" in error_msg:
                print(f"    [=] Column '{col_name}' already exists, skipping")
                continue
            # Bare raise keeps the original traceback untouched.
            raise

        print(f"    [+] Added column '{col_name}' ({col_type})")
        added_columns.append(col_name)

    return added_columns


def update_new_columns_from_source(
    target_table: str,
    source_table: str,
    id_key_column: str,
    new_columns: List[str]
) -> None:
    """
    Fill newly added target columns with source values using a Delta MERGE.

    Only target rows matched on ``id_key_column`` are updated, and only the
    listed columns are set. last_update_dt and operation_type are left untouched.

    Args:
        target_table: Target Delta table name.
        source_table: Source table name.
        id_key_column: Column used to join target and source.
        new_columns: Names of the new columns to copy from source.
    """
    if not new_columns:
        return

    def _quoted(name: str) -> str:
        # Backtick-quote (doubling embedded backticks). Without quoting,
        # col("source.a.b") would be read as field "b" of a struct "a".
        return "`" + name.replace("`", "``") + "`"

    delta_table = DeltaTable.forName(spark, target_table)
    source_df = spark.table(source_table)

    # SET clause: {target_column: source column expression}
    update_set = {c: col(f"source.{_quoted(c)}") for c in new_columns}

    print(f"    Executing MERGE UPDATE for columns: {new_columns}")

    (
        delta_table.alias("target")
        .merge(
            source_df.alias("source"),
            f"target.{_quoted(id_key_column)} = source.{_quoted(id_key_column)}"
        )
        .whenMatchedUpdate(set=update_set)
        .execute()
    )

    print(f"    [✓] Updated {len(new_columns)} column(s) with values from source")


def handle_schema_evolution(
    source_table_name: str,
    target_table_name: str,
    history_table_name: Optional[str],
    id_key_column: str,
    ignore_new_columns_as_change: bool
) -> Tuple[List[str], List[str]]:
    """
    Handle schema evolution by detecting new source columns and adding them.

    CRITICAL LOGIC:
    - source_columns: All business columns from source (for writing)
    - columns_for_comparison: Columns used for hash comparison (change detection)

    When ignore_new_columns_as_change=True:
        1. Add new columns to target with ALTER TABLE
        2. UPDATE target with values from source (MERGE)
        3. Hash comparison uses pre-existing columns still present in the
           source (source columns minus the new ones) -> no change detected
           because of the new columns

    When ignore_new_columns_as_change=False:
        1. Add new columns to target as NULL (ALTER TABLE only)
        2. No UPDATE
        3. Hash comparison uses ALL source columns -> change detected wherever
           the source has a value (NULL vs value)

    In both modes the new columns are also added to the history table when it
    is given and exists.

    Args:
        source_table_name: Source table name
        target_table_name: Target table name
        history_table_name: History table name (optional)
        id_key_column: Primary key column
        ignore_new_columns_as_change: Schema evolution mode

    Returns:
        Tuple of (source_columns, columns_for_comparison)
    """
    print("\n" + "="*60)
    print("SCHEMA EVOLUTION CHECK")
    print("="*60)

    # Get business columns (exclude technical columns from target)
    source_columns = get_table_columns(source_table_name)
    target_business_columns = get_table_columns(target_table_name, exclude_technical=True)

    print(f"Source columns count: {len(source_columns)}")
    print(f"Target business columns count: {len(target_business_columns)}")

    # Detect new columns
    new_columns, new_column_types = detect_new_columns(source_table_name, target_table_name)

    if not new_columns:
        print("\n[✓] No new columns detected - schemas are aligned")
        print("="*60 + "\n")
        return source_columns, source_columns

    print(f"\n[!] Detected {len(new_columns)} NEW COLUMN(S): {new_columns}")
    print(f"    Mode: ignore_new_columns_as_change = {ignore_new_columns_as_change}")

    if ignore_new_columns_as_change:
        # ============================================================
        # MODE: TRUE - Add columns WITH values, no change detection
        # ============================================================
        print("\n--- Processing: Add columns with actual values ---")

        # Step 1: Add columns to target table structure
        print(f"\n  Step 1: Adding columns to TARGET ({target_table_name})")
        add_columns_to_delta_table(target_table_name, new_column_types)

        # Step 2: Update values from source using MERGE
        print(f"\n  Step 2: Populating new columns with source values")
        update_new_columns_from_source(
            target_table=target_table_name,
            source_table=source_table_name,
            id_key_column=id_key_column,
            new_columns=new_columns
        )

        # Step 3: Add columns to history table (values will be NULL for historical records)
        if history_table_name and delta_table_exists(history_table_name):
            print(f"\n  Step 3: Adding columns to HISTORY ({history_table_name})")
            add_columns_to_delta_table(history_table_name, new_column_types)
            print("    Note: Historical records will have NULL for new columns")

        # For comparison, use only OLD (pre-evolution) columns. They are taken
        # from the source, not the target, so that every hashed column exists on
        # both sides: a column that exists only in the target (dropped from the
        # source) would make hashing the source DataFrame fail.
        new_column_set = set(new_columns)
        columns_for_comparison = [c for c in source_columns if c not in new_column_set]

        print(f"\n  Result:")
        print(f"    - New columns added and populated: {new_columns}")
        print(f"    - Hash comparison will use {len(columns_for_comparison)} OLD columns")
        print(f"    - Changes in existing data WILL be detected")
        print(f"    - New columns alone will NOT trigger updates")

    else:
        # ============================================================
        # MODE: FALSE - Add columns as NULL, detect as change
        # ============================================================
        print("\n--- Processing: Add columns as NULL (will detect as change) ---")

        # Step 1: Add columns to target (will be NULL)
        print(f"\n  Step 1: Adding columns to TARGET as NULL ({target_table_name})")
        add_columns_to_delta_table(target_table_name, new_column_types)

        # Step 2: Add columns to history (will be NULL)
        if history_table_name and delta_table_exists(history_table_name):
            print(f"\n  Step 2: Adding columns to HISTORY as NULL ({history_table_name})")
            add_columns_to_delta_table(history_table_name, new_column_types)

        # For comparison, use ALL columns including new ones
        # NULL (target) vs value (source) = change detected
        columns_for_comparison = source_columns

        print(f"\n  Result:")
        print(f"    - New columns added as NULL: {new_columns}")
        print(f"    - Hash comparison will use ALL {len(columns_for_comparison)} columns")
        print(f"    - All records WILL be marked as 'Update'")

    print("="*60 + "\n")

    return source_columns, columns_for_comparison


def _null_safe_row_hash(columns: List[str]):
    """
    Build the per-row change-detection hash used by create_incremental_table.

    Spark's hash()/xxhash64() skip NULL inputs. For same-typed columns, a row
    (1, NULL) therefore hashes exactly like (NULL, 1), and a value moving
    between columns goes unnoticed. Appending one is-NULL flag per column makes
    NULL positions part of the hash. The 64-bit xxhash64 also collides far less
    often than the 32-bit hash().

    Args:
        columns: Column names to hash. They are resolved against the DataFrame
            the expression is used on.

    Returns:
        A Column expression. For an empty ``columns`` list it is a constant, so
        no change is detected through it (hash() with no arguments is an error).
    """
    # Local import: xxhash64 is not imported at file level.
    from pyspark.sql.functions import xxhash64

    if not columns:
        return lit(0).cast("long")
    return xxhash64(
        *[col(c) for c in columns],
        *[col(c).isNull() for c in columns],
    )


def create_incremental_table(
    id_key_column: str,
    target_table_name: str,
    source_table_name: str,
    metadata_full_load: str,
    metadata_delta_load: str,
    create_table_params: Optional[Dict[str, Any]] = None,
    included_columns_for_hash: Optional[List[str]] = None,
    excluded_columns_for_hash: Optional[List[str]] = None,
    log_history: bool = False,
    history_table_name: Optional[str] = None,
    history_retention_days: Optional[int] = None,
    ignore_new_columns_as_change: bool = True
) -> Optional[DataFrame]:
    """
    Create or update an incremental Delta table with schema evolution support.

    make_env_tables(metadata) is expected to register the temp views "source"
    (and, for the delta load, "target") that this function reads.

    CHANGE DETECTION LOGIC:
    - Insertions: Records in source not in target (LEFT ANTI JOIN)
    - Deletions: Records in target not in source (marked as 'D')
    - Updates (tracked): Changes in specified columns → new last_update_dt
    - Updates (untracked): Changes in other columns → preserve last_update_dt
    - Reactivations: Previously deleted records reappearing in source
    - Changes are found by comparing per-row xxhash64 values of source and
      target, including per-column NULL flags, so NULLs in different columns
      are told apart.

    SCHEMA EVOLUTION:
    - New columns in source are detected automatically
    - ignore_new_columns_as_change=True: Add with values, no change trigger
    - ignore_new_columns_as_change=False: Add as NULL, triggers Update for rows
      whose new columns are non-NULL in source

    Args:
        id_key_column: Primary key column name for matching records
        target_table_name: Fully qualified target table name
        source_table_name: Fully qualified source table name
        metadata_full_load: JSON metadata for full load operation
        metadata_delta_load: JSON metadata for delta load operation
        create_table_params: Additional keyword arguments for create_table() (delta load only)
        included_columns_for_hash: Columns to include in change detection hash
        excluded_columns_for_hash: Columns to exclude from change detection hash
        log_history: Whether to log changes to history table
        history_table_name: Name of history/audit table
        history_retention_days: Days to retain history records
        ignore_new_columns_as_change: Schema evolution mode
            True (default): New columns added silently with values
            False: New columns treated as changes

    Returns:
        DataFrame with the detected changes (full-load rows on first run),
        or None if no changes were detected
    """
    create_table_params = create_table_params or {}

    # ================================================================
    # FULL LOAD - Target table does not exist
    # ================================================================
    if not delta_table_exists(target_table_name):
        logger.info("Executing full load logic")
        print("\n" + "="*60)
        print("FULL LOAD - Creating new target table")
        print("="*60)

        make_env_tables(metadata_full_load)

        df = spark.sql("""
            SELECT  *,
                    'I' AS operation_type,
                    current_timestamp() AS last_update_dt
            FROM    source 
        """)

        record_count = df.count()
        print(f"Inserting {record_count} records with operation_type='I'")

        create_table(metadata_full_load, df, skip_data_lineage=True)

        print(f"[✓] Full load completed: {target_table_name}")
        print("="*60 + "\n")

        return df

    # ================================================================
    # DELTA LOAD - Target table exists, process changes
    # ================================================================
    logger.info("Executing delta logic")
    print("\n" + "="*60)
    print("DELTA LOAD - Processing incremental changes")
    print("="*60)

    # ----------------------------------------------------------------
    # STEP 1: Handle schema evolution BEFORE creating views
    # ----------------------------------------------------------------
    # This ensures target table has new columns before we read it
    source_columns, columns_for_comparison = handle_schema_evolution(
        source_table_name=source_table_name,
        target_table_name=target_table_name,
        history_table_name=history_table_name if log_history else None,
        id_key_column=id_key_column,
        ignore_new_columns_as_change=ignore_new_columns_as_change
    )

    # Columns to write (all source columns)
    table_columns = source_columns

    # ----------------------------------------------------------------
    # STEP 2: Determine hash columns for change detection
    # ----------------------------------------------------------------
    # Start with columns_for_comparison (may exclude new columns)
    # Then apply user-specified include/exclude filters

    if included_columns_for_hash and excluded_columns_for_hash:
        # Both specified: include only specified, then exclude
        table_columns_for_hash = [
            c for c in included_columns_for_hash
            if c not in excluded_columns_for_hash 
            and c in columns_for_comparison
        ]
    elif included_columns_for_hash:
        # Only include specified
        table_columns_for_hash = [
            c for c in included_columns_for_hash
            if c in columns_for_comparison
        ]
    elif excluded_columns_for_hash:
        # Exclude specified from comparison set
        table_columns_for_hash = [
            c for c in columns_for_comparison
            if c not in excluded_columns_for_hash
        ]
    else:
        # Default: use all comparison columns
        table_columns_for_hash = list(columns_for_comparison)

    print(f"Hash configuration:")
    print(f"  - Columns for tracked changes hash: {len(table_columns_for_hash)}")
    print(f"  - Columns for all-fields hash: {len(columns_for_comparison)}")
    print(f"  - Total columns to write: {len(table_columns)}")
    if not table_columns_for_hash:
        # Previously hash() with no arguments failed the whole load.
        print("  - WARNING: no tracked hash columns; all changes will be treated as untracked")

    # ----------------------------------------------------------------
    # STEP 3: Create temp views (AFTER schema evolution)
    # ----------------------------------------------------------------
    # Now target view will include new columns with populated values
    make_env_tables(metadata_delta_load)

    source = spark.table("source")
    target = spark.table("target")

    now = datetime.now()

    print(f"\nProcessing changes at: {now}")
    print("-"*40)

    # ----------------------------------------------------------------
    # STEP 4: Detect INSERTIONS (new records in source)
    # ----------------------------------------------------------------
    df_new_insertions = (
        source.alias("s")
        .join(
            target.alias("t"),
            on=col(f"s.{id_key_column}") == col(f"t.{id_key_column}"),
            how="left_anti"
        )
        .select(*[col(c) for c in table_columns])
        .withColumn("operation_type", lit("I"))
        .withColumn("last_update_dt", lit(now))
        .withColumn("update_type", lit("Insert"))
    )
    insertions_count = df_new_insertions.count()
    print(f"  Insertions (new records):           {insertions_count:>8}")

    # ----------------------------------------------------------------
    # STEP 5: Detect DELETIONS (records removed from source)
    # ----------------------------------------------------------------
    # Note: We select from target, so columns exist
    # Only mark active records (I or U) as deleted
    df_new_deletions = (
        target.alias("t")
        .join(
            source.alias("s"),
            on=col(f"s.{id_key_column}") == col(f"t.{id_key_column}"),
            how="left_anti"
        )
        .filter(col("t.operation_type").isin("I", "U"))
        .select(*[col(f"t.{c}") for c in table_columns])
        .withColumn("operation_type", lit("D"))
        .withColumn("last_update_dt", lit(now))
        .withColumn("update_type", lit("Delete"))
    )
    deletions_count = df_new_deletions.count()
    print(f"  Deletions (removed from source):    {deletions_count:>8}")

    # ----------------------------------------------------------------
    # STEP 6: Create hashed DataFrames for change detection
    # ----------------------------------------------------------------
    # CRITICAL: Hash only on columns_for_comparison (excludes new columns if ignore=True).
    # The hash includes per-column NULL flags so NULLs in different columns differ.

    df_hashed_source = (
        source
        .withColumn("_hash_tracked", _null_safe_row_hash(table_columns_for_hash))
        .withColumn("_hash_all", _null_safe_row_hash(columns_for_comparison))
    )

    df_hashed_target = (
        target
        .withColumn("_hash_tracked", _null_safe_row_hash(table_columns_for_hash))
        .withColumn("_hash_all", _null_safe_row_hash(columns_for_comparison))
    )

    # ----------------------------------------------------------------
    # STEP 7: Detect UPDATES in tracked columns
    # ----------------------------------------------------------------
    # These get new last_update_dt and operation_type='U'
    df_new_updates_tracked = (
        df_hashed_source.alias("s")
        .join(
            df_hashed_target.alias("t"),
            on=col(f"s.{id_key_column}") == col(f"t.{id_key_column}"),
            how="inner"
        )
        .filter(
            (col("s._hash_tracked") != col("t._hash_tracked")) &
            (col("t.operation_type") != "D")
        )
        .select(
            *[col(f"s.{c}") for c in table_columns],
            lit("U").alias("operation_type"),
            lit(now).alias("last_update_dt"),
            lit("Update-tracked").alias("update_type")
        )
    )
    updates_tracked_count = df_new_updates_tracked.count()
    print(f"  Updates (tracked columns):          {updates_tracked_count:>8}")

    # ----------------------------------------------------------------
    # STEP 8: Detect UPDATES in untracked columns
    # ----------------------------------------------------------------
    # These preserve original last_update_dt and operation_type
    df_new_updates_untracked = (
        df_hashed_source.alias("s")
        .join(
            df_hashed_target.alias("t"),
            on=col(f"s.{id_key_column}") == col(f"t.{id_key_column}"),
            how="inner"
        )
        .filter(
            (col("s._hash_tracked") == col("t._hash_tracked")) &
            (col("s._hash_all") != col("t._hash_all")) &
            (col("t.operation_type") != "D")
        )
        .select(
            *[col(f"s.{c}") for c in table_columns],
            col("t.operation_type"),
            col("t.last_update_dt"),
            lit("Update-untracked").alias("update_type")
        )
    )
    updates_untracked_count = df_new_updates_untracked.count()
    print(f"  Updates (untracked columns):        {updates_untracked_count:>8}")

    # ----------------------------------------------------------------
    # STEP 9: Detect REACTIVATIONS (previously deleted, now back)
    # ----------------------------------------------------------------
    df_reactivations = (
        df_hashed_source.alias("s")
        .join(
            df_hashed_target.alias("t"),
            on=col(f"s.{id_key_column}") == col(f"t.{id_key_column}"),
            how="inner"
        )
        .filter(col("t.operation_type") == "D")
        .select(
            *[col(f"s.{c}") for c in table_columns],
            lit("U").alias("operation_type"),
            lit(now).alias("last_update_dt"),
            lit("Reactivate").alias("update_type")
        )
    )
    reactivations_count = df_reactivations.count()
    print(f"  Reactivations (un-deleted):         {reactivations_count:>8}")

    # ----------------------------------------------------------------
    # STEP 10: Calculate total and process
    # ----------------------------------------------------------------
    total_count = (
        insertions_count + 
        deletions_count + 
        updates_tracked_count + 
        updates_untracked_count + 
        reactivations_count
    )
    print("-"*40)
    print(f"  TOTAL CHANGES:                      {total_count:>8}")

    if total_count == 0:
        print("\n[✓] Source and Target are aligned. No updates to process!")
        print("="*60 + "\n")
        return None

    # ----------------------------------------------------------------
    # STEP 11: Union all changes
    # ----------------------------------------------------------------
    # union() is positional; every branch above yields
    # table_columns + operation_type, last_update_dt, update_type in that order.
    df_all_changes = (
        df_new_insertions
        .union(df_new_deletions)
        .union(df_new_updates_tracked)
        .union(df_new_updates_untracked)
        .union(df_reactivations)
    )

    # Remove internal update_type column for target table
    df_for_target = df_all_changes.drop("update_type")

    # ----------------------------------------------------------------
    # STEP 12: Log to history table (optional)
    # ----------------------------------------------------------------
    if log_history and history_table_name:
        print(f"\nHistory logging enabled → {history_table_name}")

        df_for_history = df_all_changes.withColumn("log_datetime", lit(now))

        if delta_table_exists(history_table_name):
            print(f"  Appending {total_count} records to existing history table")
            (
                df_for_history.write
                .format("delta")
                .mode("append")
                .option("mergeSchema", "true")  # Handle new columns
                .saveAsTable(history_table_name)
            )
        else:
            print(f"  Creating new history table with {total_count} records")
            (
                df_for_history.write
                .format("delta")
                .mode("overwrite")
                .option("overwriteSchema", "true")
                .saveAsTable(history_table_name)
            )

        # Apply retention policy
        if history_retention_days and history_retention_days > 0:
            cutoff_date = (datetime.now() - timedelta(days=history_retention_days)).date()
            cutoff_str = cutoff_date.strftime('%Y-%m-%d')

            delta_history = DeltaTable.forName(spark, history_table_name)
            old_records = (
                delta_history.toDF()
                .filter(to_date("log_datetime") < lit(cutoff_str))
                .count()
            )

            if old_records > 0:
                print(f"  Retention: Deleting {old_records} records older than {cutoff_str}")
                delta_history.delete(f"to_date(log_datetime) < '{cutoff_str}'")

    # ----------------------------------------------------------------
    # STEP 13: Write changes to target table
    # ----------------------------------------------------------------
    print(f"\nWriting {total_count} changes to target table...")
    create_table(metadata_delta_load, df_for_target, **create_table_params)

    print(f"\n[✓] Delta load completed successfully")
    print("="*60 + "\n")

    return df_all_changes

# Declare variables

## Mount points 

In [9]:
# print("Declare 'mount_point' variable (dict) for synfs paths to ADLS")

from collections import defaultdict

# ---------------------------------------------------------------------------
# GLOBAL VARIABLES
# ---------------------------------------------------------------------------

# Identifier of the current job, as reported by mssparkutils.
job_id = mssparkutils.env.getJobId()

# Mount registries keyed by mount name: mount point and synfs path
# (e.g. filled by export_to_synfs_ADLS via mount_from_linkedservice).
mount_point, mount_synfs_path = defaultdict(dict), defaultdict(dict)

## Process date

In [None]:
from datetime import timedelta

# `process_date` must already exist when this cell runs (presumably a notebook parameter).
# If it is empty/None, default to two days before today, formatted as YYYY-MM-DD.
if not process_date:
    today = date.today()
    process_date = (today - timedelta(days=2)).strftime('%Y-%m-%d')
    print(f"Process date has not been declared, new process date is: {process_date}")

## Workspace env

In [None]:
workspace_name = mssparkutils.env.getWorkspaceName()
# Environment tokens (dev/prod/uat) that appear immediately before "syn" in the workspace name.
# Note: re.findall returns a list of matches, not a single string.
workspace_env = re.findall(r"(dev|prod|uat)(?=syn)", workspace_name)

# Get total cores.
# getExecutorMemoryStatus() has one entry per block manager, including the driver's, hence the "- 1".
# With no executors registered (e.g. a driver-only session) that yields 0, which would make
# BASE_PARTITION_COUNT 0 -- not a valid partition count -- so count at least one executor.
executor_count = max(spark.sparkContext._jsc.sc().getExecutorMemoryStatus().size() - 1, 1)
# spark.executor.cores is not guaranteed to be set (e.g. local mode); assume 1 instead of failing.
cores_per_executor = int(spark.conf.get("spark.executor.cores", "1"))
total_cores = executor_count * cores_per_executor

# Partitioning and repartitioning optimization: two tasks per available core.
optimal_parallelism = total_cores * 2
#spark.conf.set("spark.default.parallelism", str(optimal_parallelism))
#spark.conf.set("spark.sql.shuffle.partitions", str(optimal_parallelism))
BASE_PARTITION_COUNT = optimal_parallelism


#print(f"BASE_PARTITION_COUNT (for table partitioning and spark.default.parallelism / spark.sql.shuffle.partitions): {BASE_PARTITION_COUNT}" )

## Override Delta Sharing profile loading (read profiles from Key Vault)

In [7]:
try:

    import delta_sharing
    import json
    from delta_sharing.rest_client import DeltaSharingProfile

    @staticmethod
    def custom_read_from_file(profile: str) -> "DeltaSharingProfile":
        """
        Replacement for DeltaSharingProfile.read_from_file that loads the profile JSON from
        Azure Key Vault (through a linked service) instead of from a file.

        Args:
            profile: "<key_vault_linked_service>;<secret_name>", i.e. the part of a Delta
                Sharing table URL before '#' (see the usage example after this cell's code).

        Returns:
            The DeltaSharingProfile parsed from the secret's JSON content.

        Raises:
            ValueError: if `profile` does not split into exactly two ';'-separated parts, or if
                the secret content cannot be parsed as a DeltaSharingProfile. The underlying
                error is chained as __cause__.
        """
        # Split the profile parameter into key_vault_ls and secret_name.
        try:
            key_vault_ls, secret_name = profile.split(";")
        except ValueError as err:
            raise ValueError("Invalid format for profile. Expected format: 'key_vault_ls;secret_name'") from err

        # Logs the linked-service / secret names only, never the secret value.
        print(profile)

        # Retrieve the secret from Azure Key Vault via the linked service.
        secret_content = mssparkutils.credentials.getSecretWithLS(key_vault_ls, secret_name)

        # Parse the secret content into a DeltaSharingProfile object; chain the parser's error
        # so the root cause stays visible in the traceback.
        try:
            return DeltaSharingProfile.from_json(secret_content)
        except Exception as e:
            raise ValueError(f"Failed to parse DeltaSharingProfile: {e}") from e

    # Patch the library so every profile lookup goes through Key Vault.
    DeltaSharingProfile.read_from_file = custom_read_from_file
    print(f"    Delta sharing imported, version: {delta_sharing.__version__}")
except ImportError:
    print("    No delta sharing module")


# *** Usage ***
# table_url = "SOLUTION_KEY_VAULT_LS;DeltaSharingProfile#corporate_finance_analytics_sustainability.finsight_core_model.dimensions_item_taxonomy"
## SOLUTION_KEY_VAULT_LS - key vault linked service name
## DeltaSharingProfile secret with delta sharing profile (token, endpoint)
# df = delta_sharing.load_as_pandas(table_url)

## delete_folder_with_linked_service

In [None]:
def delete_folder_with_linked_service(Linked_Service_Name, container, path, ENV, recursive = True):
    """
    Delete a folder (or file) in a linked-service storage container through a workspace mount.

    The container is mounted with `mount_from_linkedservice`, and the results are cached in the
    module-level `mount_point` / `mount_synfs_path` dicts under the mount-point name. `path` is
    appended verbatim to that mount's synfs path.

    Args:
        Linked_Service_Name: Synapse linked service used to mount the storage.
        container: Container to mount.
        path: Path appended verbatim to the mount's synfs path.
        ENV: Environment value forwarded to get_mount_name / mount_from_linkedservice.
        recursive: Passed to mssparkutils.fs.rm; True removes the folder with its contents.

    Deletion is best-effort. A failure (e.g. the path does not exist) is printed together with
    the actual error and is not raised.
    """
    # Define the mounting scope and retrieve the mount point name.
    mount_scope = "workspace"
    mount_point_name = get_mount_name(Linked_Service_Name, container, mount_scope, ENV)

    # Mount the linked service storage container and cache the returned paths.
    mount_point[mount_point_name], mount_synfs_path[mount_point_name] = mount_from_linkedservice(
        Linked_Service_Name, container, mount_scope, ENV
    )

    # Full synfs path of the folder to delete.
    folder_synfs_path = f"{mount_synfs_path[mount_point_name]}{path}"

    print(f"Deleting data from: {folder_synfs_path}")
    try:
        mssparkutils.fs.rm(folder_synfs_path, recursive)
    except Exception as err:
        # A missing folder is the expected failure, but report the real error rather than
        # assuming it. Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        print(f"Folder does not exist or could not be deleted: {err}")


### delete_files_in_folder_with_linked_service

In [None]:
def delete_files_in_folder_with_linked_service(Linked_Service_Name, container, path, ENV, recursive = False):
    """
    Delete each entry returned by `list_files_recursive` for a folder in a linked-service container.

    The container is mounted with `mount_from_linkedservice`, and the results are cached in the
    module-level `mount_point` / `mount_synfs_path` dicts. Entries come from
    list_files_recursive(<synfs folder>, extension=None, file_name_like=None). Each entry is
    expected to be a dict with a 'path' key; presumably only files are returned, given the
    name -- confirm.

    Args:
        Linked_Service_Name: Synapse linked service used to mount the storage.
        container: Container to mount.
        path: Folder path appended verbatim to the mount's synfs path.
        ENV: Environment value forwarded to get_mount_name / mount_from_linkedservice.
        recursive: Passed to mssparkutils.fs.rm for every entry.

    Deletion is best-effort per entry. Failures are printed with the actual error and are not raised.
    """
    # Define the mounting scope and retrieve the mount point name.
    mount_scope = "workspace"
    mount_point_name = get_mount_name(Linked_Service_Name, container, mount_scope, ENV)

    # Mount the linked service storage container and cache the returned paths.
    mount_point[mount_point_name], mount_synfs_path[mount_point_name] = mount_from_linkedservice(
        Linked_Service_Name, container, mount_scope, ENV
    )

    # Full synfs path of the folder whose contents are deleted.
    folder_synfs_path = f"{mount_synfs_path[mount_point_name]}{path}"
    files = list_files_recursive(folder_synfs_path, extension = None, file_name_like = None)

    for entry in files:
        file_path = entry['path']
        print(f"Deleting data from: {file_path}")
        try:
            mssparkutils.fs.rm(file_path, recursive)
        except Exception as err:
            # Report the real error instead of assuming the file is missing. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            print(f"File {file_path} does not exist or could not be deleted: {err}")


### audit_whitespace_usage

In [None]:
# HOOGLPE1 20250625 : audit report + 'normalize' and/or 'trim' function to replace all whitespace-like but not-space chars with space ' ' \u0020, and trim leading/trailing blanks if requested

from pyspark.sql.functions import col, trim, concat_ws, lit, udf
from pyspark.sql.types import StringType, ArrayType, StructType, StructField
import unicodedata
import pandas as pd

# === Default whitespace-like characters ===
# Character -> visible tag. The tags are shown by audit_whitespace_usage, and the keys are the
# characters normalize_whitespace_fields replaces with a plain space. Entries are listed (and
# therefore iterated) in ascending code-point order.
# NOTE(review): U+0085 (NEL) and U+200B (ZERO WIDTH SPACE) are not listed, so they are neither
# audited nor normalized -- confirm this is intended.
DEFAULT_WHITESPACE_MAP = {
    chr(code_point): tag
    for code_point, tag in (
        (0x0009, '<TAB>'),
        (0x000A, '<LF>'),
        (0x000B, '<VT>'),
        (0x000C, '<FF>'),
        (0x000D, '<CR>'),
        (0x0020, '<SPACE>'),
        (0x00A0, '<NBSP>'),
        (0x1680, '<OGHAM_SPACE>'),
        (0x180E, '<MONGOLIAN_VOWEL_SEP>'),
        (0x2000, '<EN_QUAD>'),
        (0x2001, '<EM_QUAD>'),
        (0x2002, '<EN_SPACE>'),
        (0x2003, '<EM_SPACE>'),
        (0x2004, '<THREE_PER_EM_SPACE>'),
        (0x2005, '<FOUR_PER_EM_SPACE>'),
        (0x2006, '<SIX_PER_EM_SPACE>'),
        (0x2007, '<FIGURE_SPACE>'),
        (0x2008, '<PUNCTUATION_SPACE>'),
        (0x2009, '<THIN_SPACE>'),
        (0x200A, '<HAIR_SPACE>'),
        (0x2028, '<LINE_SEPARATOR>'),
        (0x2029, '<PARAGRAPH_SEPARATOR>'),
        (0x202F, '<NARROW_NO_BREAK_SPACE>'),
        (0x205F, '<MEDIUM_MATH_SPACE>'),
        (0x3000, '<IDEOGRAPHIC_SPACE>'),
        (0xFEFF, '<ZERO_WIDTH_NBSP>'),
    )
}

# === Audit Function ===
def audit_whitespace_usage(df, key_columns, columns_to_process=None, whitespace_map=None):
    """
    Report values that need trimming or contain whitespace-like characters other than a plain space.

    One output row is produced per offending (row, column) pair. All output columns are strings:
    keyfield_id, column_name, original_value, visualized_value,
    needs_trimming ('Y'/'N'), has_nonspace_whitespace ('Y'/'N').

    Args:
        df: DataFrame to audit.
        key_columns: List of column names identifying a row. Their values are cast to string and
            joined with '|' into keyfield_id.
        columns_to_process: Columns to audit. When empty/None, every non-key column is audited.
            Values are cast to string before inspection, so non-string columns are safe to include.
        whitespace_map: Optional mapping whose keys are characters the caller already handles
            (e.g. the map given to normalize_whitespace_fields). Those characters are removed from
            DEFAULT_WHITESPACE_MAP for both detection and visualization.

    Returns:
        The union of the audit rows over all audited columns. If there are no columns to audit,
        returns an empty DataFrame with the same schema.

    Raises:
        ValueError: if key_columns is not a list.
    """
    if not isinstance(key_columns, list):
        raise ValueError("key_columns must be a list")

    # Always start from the full known list, minus the characters the caller already handles.
    excluded_chars = set(whitespace_map.keys()) if whitespace_map else set()
    audit_chars = {k: v for k, v in DEFAULT_WHITESPACE_MAP.items() if k not in excluded_chars}

    # Plain Python objects captured by the UDF closures below (not Spark broadcast variables).
    whitespace_set = set(audit_chars)
    visual_map = dict(audit_chars)

    @udf(returnType=StringType())
    def visualize_whitespace(val):
        if val is None:
            return None
        return ''.join([visual_map.get(ch, ch) for ch in val])

    @udf(returnType=StringType())
    def needs_trim(val):
        if val is None or not isinstance(val, str):
            return 'N'
        return 'Y' if val != val.strip() else 'N'

    @udf(returnType=StringType())
    def contains_nonspace_whitespace(val):
        if val is None or not isinstance(val, str):
            return 'N'
        for ch in val:
            if ch in whitespace_set and ch != '\u0020':
                return 'Y'
        return 'N'

    if not columns_to_process:
        columns_to_process = [c for c in df.columns if c not in key_columns]

    # Key values are cast explicitly because concat_ws expects string inputs.
    keyfield_id = concat_ws("|", *[col(k).cast("string") for k in key_columns])

    audit_rows = []
    for colname in columns_to_process:
        # The UDFs iterate over characters, so they must receive strings. A raw numeric or date
        # value would raise inside visualize_whitespace. The cast also keeps original_value a
        # string column in every branch of the unionByName below, matching the empty-result schema.
        value = col(colname).cast("string")

        result = (
            df.select(
                keyfield_id.alias("keyfield_id"),
                lit(colname).alias("column_name"),
                value.alias("original_value"),
                visualize_whitespace(value).alias("visualized_value"),
                needs_trim(value).alias("needs_trimming"),
                contains_nonspace_whitespace(value).alias("has_nonspace_whitespace"),
            )
            .filter((col("needs_trimming") == "Y") | (col("has_nonspace_whitespace") == "Y"))
        )
        audit_rows.append(result)

    if audit_rows:
        from functools import reduce
        return reduce(lambda a, b: a.unionByName(b), audit_rows)

    return spark.createDataFrame([], schema=StructType([
        StructField("keyfield_id", StringType()),
        StructField("column_name", StringType()),
        StructField("original_value", StringType()),
        StructField("visualized_value", StringType()),
        StructField("needs_trimming", StringType()),
        StructField("has_nonspace_whitespace", StringType()),
    ]))

def normalize_whitespace_fields(
    df,
    columns_to_process=None,
    trim=True,
    normalize=True,
    whitespace_map=None,
):
    """
    Clean whitespace in DataFrame columns with a Python UDF.

    NULLs and non-str values pass through unchanged. Other values are transformed as follows:
      * normalize=True: each key of the whitespace map is replaced by a plain space (U+0020),
        one key at a time in the map's iteration order;
      * trim=True: leading/trailing whitespace is then removed with Python's str.strip().

    Args:
        df: Input DataFrame.
        columns_to_process: List of column names. None or [] selects every StringType column.
            Names that are not columns of df are skipped.
        trim: Strip leading/trailing whitespace after normalization.
        normalize: Replace whitespace-like characters with a plain space.
        whitespace_map: Mapping whose keys are the strings to replace. A falsy value falls back
            to DEFAULT_WHITESPACE_MAP.

    Returns:
        DataFrame with each processed column replaced by its cleaned StringType value.

    Raises:
        TypeError: columns_to_process is neither None nor a list.
        ValueError: columns_to_process contains a non-string or blank entry.
    """
    active_map = whitespace_map or DEFAULT_WHITESPACE_MAP

    if columns_to_process is None or columns_to_process == []:
        targets = [
            field.name
            for field in df.schema.fields
            if isinstance(field.dataType, StringType)
        ]
    else:
        if not isinstance(columns_to_process, list):
            raise TypeError("columns_to_process must be a list of strings or None.")
        if not all(isinstance(c, str) and c.strip() for c in columns_to_process):
            raise ValueError("columns_to_process must be a list of non-empty strings.")
        targets = columns_to_process

    replace_keys = list(active_map.keys())
    space = "\u0020"

    def _clean(text):
        # Leave NULLs and non-string values exactly as they are.
        if text is None or not isinstance(text, str):
            return text
        if normalize:
            for needle in replace_keys:
                text = text.replace(needle, space)
        return text.strip() if trim else text

    clean_udf = udf(_clean, StringType())

    available = set(df.columns)
    result_df = df
    for name in targets:
        if name in available:
            result_df = result_df.withColumn(name, clean_udf(name))
    return result_df


### full_compare_2_dfs

In [None]:
from pyspark.sql.functions import lit, col, concat_ws

def full_compare_2_dfs(
    df1,
    df2,
    key_fields,
    id_label1="df1",
    id_label2="df2"
):
    """
    Return the rows that differ between two DataFrames with identical schemas.

    Rows of df1 missing from df2, and vice versa, are returned together. Each is tagged with
    `_source` (id_label1 / id_label2) and a `keyfield_id` made of the key values joined with '|'.
    A row whose non-key values changed therefore appears twice with the same keyfield_id, once
    per side. concat_ws skips NULL key values, so e.g. keys (1, NULL) and (NULL, 1) produce the
    same keyfield_id.

    Rows are compared with DataFrame.subtract (EXCEPT DISTINCT). Duplicate rows are therefore
    collapsed, and differences in how often an identical row occurs are not reported.

    Args:
        df1, df2: DataFrames to compare. Column names and types must match; column order may differ.
        key_fields: Column name, or list of column names, identifying a row.
        id_label1, id_label2: Labels used in `_source` and in schema-mismatch messages.

    Returns:
        DataFrame with columns keyfield_id, _source, then df1's columns in df1's order,
        sorted by keyfield_id and _source.

    Raises:
        ValueError: if the schemas differ (columns only on one side, or type mismatches).
    """

    def _schema_diffs(a, b):
        """Return (only_in_a, only_in_b, [(col, a_type, b_type), ...]) for two DataFrames."""
        a_cols, b_cols = set(a.columns), set(b.columns)
        only_in_a = sorted(a_cols - b_cols)
        only_in_b = sorted(b_cols - a_cols)
        # Check type mismatches for common columns.
        a_types = {f.name: f.dataType.simpleString() for f in a.schema.fields}
        b_types = {f.name: f.dataType.simpleString() for f in b.schema.fields}
        common = a_cols & b_cols
        type_mismatch = sorted(
            [(c, a_types[c], b_types[c]) for c in common if a_types[c] != b_types[c]]
        )
        return only_in_a, only_in_b, type_mismatch

    if isinstance(key_fields, str):
        key_fields = [key_fields]

    # Validate that schemas match (names & types). Give a precise diff if not.
    only_in_df1, only_in_df2, type_mismatch = _schema_diffs(df1, df2)
    if only_in_df1 or only_in_df2 or type_mismatch:
        msgs = []
        if only_in_df1:
            msgs.append(f"Only in {id_label1}: {only_in_df1}")
        if only_in_df2:
            msgs.append(f"Only in {id_label2}: {only_in_df2}")
        if type_mismatch:
            # e.g. [('amount','bigint','string'), ...]
            pretty = [f"{c}: {t1} vs {t2}" for c, t1, t2 in type_mismatch]
            # Use the caller's labels, like the other messages; the defaults keep the old wording.
            msgs.append(
                f"Type mismatches (col: {id_label1} vs {id_label2}): " + ", ".join(pretty)
            )
        raise ValueError("Schema mismatch. " + " | ".join(msgs))

    # Use df1 column order as canonical.
    table_cols = df1.columns
    df2 = df2.select(*table_cols)

    # Cast keys explicitly, as audit_whitespace_usage does. concat_ws expects string inputs,
    # so non-string keys do not depend on implicit coercion.
    keyfield_id = concat_ws("|", *[col(k).cast("string") for k in key_fields])

    diff_1 = (
        df1.subtract(df2)
        .withColumn("_source", lit(id_label1))
        .withColumn("keyfield_id", keyfield_id)
    )
    diff_2 = (
        df2.subtract(df1)
        .withColumn("_source", lit(id_label2))
        .withColumn("keyfield_id", keyfield_id)
    )

    # Reorder and output.
    cols_order = ["keyfield_id", "_source"] + table_cols
    combined = (
        diff_1.select(*cols_order)
        .unionByName(diff_2.select(*cols_order))
        .orderBy("keyfield_id", "_source")
    )

    return combined

### convert_columns_to_date

In [None]:
from pyspark.sql import DataFrame
from typing import List

print("Function: convert_columns_to_date")

def convert_columns_to_date(
    df: DataFrame, 
    column_list: List[str], 
    date_format: str = "yyyyMMdd"
) -> DataFrame:
    """
    Converts text columns to date type with proper null handling.

    Each column is first cast to string and trimmed. Then:
      * NULL, blank ("") and "0" values become NULL;
      * values that to_date cannot parse with `date_format` become NULL. With
        spark.sql.ansi.enabled, to_date may raise on invalid input instead -- confirm per
        environment;
      * dates before 1900-01-01 are set to 1900-01-01.

    Args:
        df: Input DataFrame
        column_list: List of column names to convert
        date_format: Spark datetime pattern used to parse the values (default 'yyyyMMdd')

    Returns:
        DataFrame with the listed columns replaced by DateType values
    """
    min_allowed_date = F.to_date(F.lit("1900-01-01"))

    for col_name in column_list:
        # Cast and trim once so the blank checks and the parse all see the same value,
        # whatever the source column type.
        trimmed = F.trim(F.col(col_name).cast("string"))
        converted_date = F.to_date(trimmed, date_format)

        df = df.withColumn(
            col_name,
            F.when(
                trimmed.isNull() | (trimmed == "") | (trimmed == "0"),
                F.lit(None).cast("date")
            ).when(
                # Clamp explicitly instead of using F.greatest: greatest skips NULLs and would turn
                # every unparseable value into 1900-01-01. Here a NULL parse result fails this
                # comparison and stays NULL via otherwise().
                converted_date < min_allowed_date,
                min_allowed_date
            ).otherwise(converted_date)
        )
    return df

### trim_leading_zeros

In [1]:
from pyspark.sql import DataFrame
from typing import List

print("Function: trim_leading_zeros")

def trim_leading_zeros(df: DataFrame, column_list: List[str]) -> DataFrame:
    """
    Strip leading zeros from each listed column, always keeping at least one character.

    The pattern '^0+(?=.)' removes the leading run of zeros only when another character follows
    it, so an all-zero value collapses to a single '0'. NULL values stay NULL.

    Note: any following character qualifies, so '0.5' becomes '.5' and '00A1' becomes 'A1'.
    This suits identifier-style codes; confirm before applying it to decimal strings.

    Args:
        df: DataFrame to transform.
        column_list: Names of the columns to rewrite in place.

    Returns:
        The DataFrame with the listed columns replaced by their zero-stripped values.

    Examples:
        '0001' -> '1'
        '0000' -> '0'
        '0' -> '0'
    """
    leading_zeros = '^0+(?=.)'
    result = df
    for name in column_list:
        stripped = F.regexp_replace(F.col(name), leading_zeros, '')
        result = result.withColumn(name, stripped)
    return result

### rename_columns

In [None]:
def rename_columns(df, column_mapping):
    """
    Rename DataFrame columns according to a {old_name: new_name} mapping.

    Only mapping entries whose old name is a column of the *input* DataFrame are applied; the
    rest are ignored. Renames are applied one after another in the mapping's iteration order.

    Args:
        df: Input DataFrame with the original column names.
        column_mapping: Dictionary mapping {old_name: new_name}.

    Returns:
        DataFrame with the matching columns renamed.
    """
    # Snapshot of the original names; membership is always checked against the input DataFrame.
    present = set(df.columns)
    result = df
    for source_name, target_name in column_mapping.items():
        if source_name not in present:
            continue
        result = result.withColumnRenamed(source_name, target_name)
    return result

# Example
#
## Column mapping
#column_mapping = {
#    'grdcode': 'GRD_code',
#    'packagingSpecNo': 'PackagingSpecNo',
#    'quantity': 'Quantity'
#}
#
## Function call
#df_transformed = rename_columns(df_source, column_mapping)

In [None]:
# Last statement of the notebook: when run top to bottom, reaching it means the cells above completed.
print("Done !")