In [None]:
# Widgets: job parameters, all read as strings by the main execution cell.
dbutils.widgets.text("jdbc_hostname", "")
dbutils.widgets.text("jdbc_port", "")
dbutils.widgets.text("jdbc_db", "")
dbutils.widgets.text("user_name", "")
dbutils.widgets.text("table_name", "")
dbutils.widgets.text("dest_catalog", "")
dbutils.widgets.text("scope", "")
dbutils.widgets.text("dest_schema", "")
dbutils.widgets.text("src_schema", "")
dbutils.widgets.text("db_type", "")
dbutils.widgets.text("overwrite", "")
dbutils.widgets.text("incremental_field", "")
# Read by the main cell; left blank, the merge key is looked up from the source.
dbutils.widgets.text("table_key_name", "")

In [None]:
# Imports and logging setup
import logging
import typing as t
from delta.tables import DeltaTable
from pyspark.sql import DataFrame
from pyspark.sql.utils import AnalysisException
from pyspark.sql.functions import lower, collect_list, date_format

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# basicConfig() is a no-op when the root logger already has handlers (as it can
# in a managed notebook runtime), so set the level on this logger explicitly to
# make sure INFO progress messages are not filtered out.
logger.setLevel(logging.INFO)

In [None]:
# Main execution
# NOTE(review): this cell calls helpers (get_jdbc_url, process_table, ...) that are
# defined in the "Helper functions" cell further down, so "Run All" on a fresh
# kernel fails with NameError. TODO: move this cell below the helper cell.

# Read job parameters.
db_type = dbutils.widgets.get("db_type")
jdbc_hostname = dbutils.widgets.get("jdbc_hostname")
jdbc_port = dbutils.widgets.get("jdbc_port")
jdbc_db = dbutils.widgets.get("jdbc_db")
user_name = dbutils.widgets.get("user_name")
scope = dbutils.widgets.get("scope")
dest_catalog = dbutils.widgets.get("dest_catalog")
dest_schema = dbutils.widgets.get("dest_schema")
src_schema = dbutils.widgets.get("src_schema")
table_name = dbutils.widgets.get("table_name")
# Optional: blank means process_table looks the key up in the source database.
# NOTE(review): requires a "table_key_name" widget to be defined.
table_key_name = dbutils.widgets.get("table_key_name")
# Optional: blank means a full merge instead of an incremental load.
incremental_field = dbutils.widgets.get("incremental_field")
# Anything other than "true" (ignoring case and surrounding whitespace) means no overwrite.
overwrite = dbutils.widgets.get("overwrite").strip().lower() == "true"

# Fail fast with a readable message rather than a confusing JDBC/SQL error later.
_required_params = {
    "db_type": db_type,
    "jdbc_hostname": jdbc_hostname,
    "jdbc_port": jdbc_port,
    "jdbc_db": jdbc_db,
    "user_name": user_name,
    "scope": scope,
    "dest_catalog": dest_catalog,
    "dest_schema": dest_schema,
    "src_schema": src_schema,
    "table_name": table_name,
}
_missing_params = [name for name, value in _required_params.items() if not value.strip()]
if _missing_params:
    raise ValueError(f"Missing required widget value(s): {', '.join(_missing_params)}")

password = dbutils.secrets.get(scope=scope, key=user_name)

jdbc_url = get_jdbc_url(db_type, jdbc_hostname, jdbc_port, jdbc_db)
jdbc_driver = get_jdbc_driver(db_type)
connection_props = get_connection_properties(user_name, password, jdbc_driver)

# Separate try blocks so the log says which stage failed.
try:
    check_and_set_catalog_schema(dest_catalog, dest_schema)
except ValueError as e:
    logger.error(f"Error setting up catalog and schema: {str(e)}")
    raise

try:
    process_table(
        jdbc_url,
        connection_props,
        dest_catalog,
        dest_schema,
        table_name,
        src_schema,
        db_type,
        overwrite,
        table_key_name,
        incremental_field,
    )
except ValueError as e:
    logger.error(f"Error processing table {src_schema}.{table_name}: {str(e)}")
    raise

In [None]:
# Helper functions
def get_connection_properties(
    user_name: str, 
    password: str, 
    driver: str
    ) -> t.Dict[str, str]:
    """
    Build the property dict handed to ``spark.read.jdbc``.

    Args:
        user_name: Database login name.
        password: Password for ``user_name``.
        driver: Fully-qualified JDBC driver class name.

    Returns:
        A new dict of string-valued JDBC connection properties.
    """
    props: t.Dict[str, str] = dict(
        user=user_name,
        password=password,
        driver=driver,
    )
    # NOTE(review): for the Microsoft SQL Server driver this skips TLS certificate
    # validation — confirm that is intended.
    props["trustServerCertificate"] = "true"
    # NOTE(review): Spark's own JDBC option is spelled "numPartitions"; this key is
    # forwarded to the driver verbatim — confirm it has any effect.
    props["num_partitions"] = "25"
    return props


def get_jdbc_url(
    db_type: str, 
    hostname: str, 
    port: str, db: str
    ) -> str:
    """
    Build the JDBC connection URL for a supported database type.

    Args:
        db_type: Either "sqlserver" or "db2".
        hostname: Database host.
        port: Database port, as a string.
        db: Database name.

    Returns:
        The JDBC URL string.

    Raises:
        ValueError: If ``db_type`` is not supported.
    """
    url_templates = {
        "sqlserver": "jdbc:sqlserver://{host}:{port};database={db}",
        "db2": "jdbc:db2://{host}:{port}/{db}",
    }
    template = url_templates.get(db_type)
    if template is None:
        raise ValueError(f"Unsupported database type: {db_type}")
    return template.format(host=hostname, port=port, db=db)


def get_jdbc_driver(db_type: str) -> str:
    """
    Map a database type to its JDBC driver class name.

    Raises:
        ValueError: If ``db_type`` is neither "sqlserver" nor "db2".
    """
    drivers_by_type = {
        "sqlserver": "com.microsoft.sqlserver.jdbc.SQLServerDriver",
        "db2": "com.ibm.db2.jcc.DB2Driver",
    }
    if db_type in drivers_by_type:
        return drivers_by_type[db_type]
    raise ValueError(f"Unsupported database type: {db_type}")


def read_jdbc_data(
    jdbc_url: str, 
    table_name: str, 
    connection_props: t.Dict[str, str]
    ) -> DataFrame:
    """
    Load a JDBC table into a Spark DataFrame.

    Args:
        jdbc_url: JDBC URL of the source database.
        table_name: A table name, or a parenthesised subquery with an alias.
        connection_props: JDBC properties (user, password, driver, ...).

    Returns:
        The lazily-evaluated Spark DataFrame.
    """
    reader = spark.read
    return reader.jdbc(url=jdbc_url, table=table_name, properties=connection_props)


def remove_spaces_from_column_headers(df: DataFrame) -> DataFrame:
    """
    Return ``df`` with every space, "(" and ")" deleted from each column name.

    Columns are renamed one at a time via ``withColumnRenamed``. Two headers
    that sanitize to the same name (e.g. "A B" and "AB") can end up as
    duplicate columns.
    """
    # Translation table that deletes the three characters.
    drop_chars = str.maketrans("", "", " ()")
    renamed = df
    for original_name in df.columns:
        renamed = renamed.withColumnRenamed(
            original_name, original_name.translate(drop_chars)
        )
    return renamed


def get_primary_keys(
    jdbc_url: str,
    schema: str,
    table_name: str,
    connection_props: t.Dict[str, str],
    db_type: str,
    ) -> t.Dict[str, t.List[str]]:
    """
    Look up the key columns of ``schema.table_name`` in the source database.

    SQL Server: reads the declared PRIMARY KEY constraint from INFORMATION_SCHEMA.
    Db2: reads the primary-key index (UNIQUERULE = 'P') from SYSIBM.SYSINDEXES.
    If the table has none, it falls back to the table's UNIQUERULE = 'D' indexes,
    and the columns of all of them are combined.

    Args:
        jdbc_url: JDBC URL of the source database.
        schema: Source schema name.
        table_name: Source table name.
        connection_props: JDBC connection properties.
        db_type: "sqlserver" or "db2".

    Returns:
        A dict mapping a table identifier to its key column names. The identifier
        is the lower-cased table name for SQL Server, and lower-cased
        "schema.table" for Db2. For SQL Server the dict is empty when no primary
        key exists.

    Raises:
        ValueError: If ``db_type`` is not supported.
    """
    if db_type not in ("sqlserver", "db2"):
        raise ValueError(f"Unsupported database type: {db_type}")

    # Escape single quotes: the names are embedded in SQL string literals.
    schema_lit = schema.replace("'", "''")
    table_lit = table_name.replace("'", "''")

    if db_type == "sqlserver":
        # Constraint names are only unique within a schema, so join on schema as
        # well. Otherwise a same-named table/constraint in another schema could
        # contribute columns.
        query = f"""
        (SELECT kcu.TABLE_NAME, kcu.COLUMN_NAME
        FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
        JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE AS kcu
        ON tc.CONSTRAINT_TYPE = 'PRIMARY KEY'
        AND tc.CONSTRAINT_SCHEMA = kcu.CONSTRAINT_SCHEMA
        AND tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME
        AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA
        AND tc.TABLE_NAME = kcu.TABLE_NAME
        WHERE tc.TABLE_SCHEMA = '{schema_lit}' AND tc.TABLE_NAME = '{table_lit}') AS primary_key_info
        """
        df = read_jdbc_data(jdbc_url, query, connection_props)
        result = (
            df.groupBy(lower(df.TABLE_NAME).alias("table_name"))
            .agg(collect_list("COLUMN_NAME").alias("column_names"))
            .collect()
        )
        return {row["table_name"]: row["column_names"] for row in result}

    # db_type == "db2"
    schema_lc = schema_lit.lower()
    table_lc = table_lit.lower()
    # COLNAMES prefixes every column with '+' (ascending) or '-' (descending), per
    # the Db2 catalog docs. Drop the leading marker and turn both separators into
    # commas. The previous version only handled '+'.
    # NOTE(review): UNIQUERULE 'D' indexes allow duplicates, so the fallback
    # columns may not uniquely identify rows — confirm this fallback is wanted.
    query = f"""
    (WITH table_index_p AS (
    SELECT
        si.tbname,
        RTRIM(si.name) AS index_name,
        RTRIM(si.creator) AS index_schema,
        si.uniquerule,
        si.colnames
    FROM sysibm.sysindexes si
    WHERE lower(si.tbname) = '{table_lc}'
    AND lower(si.tbcreator) = '{schema_lc}'
    AND si.uniquerule = 'P'
    ),
    table_index_d AS (
    SELECT
        si.tbname,
        RTRIM(si.name) AS index_name,
        RTRIM(si.creator) AS index_schema,
        si.uniquerule,
        si.colnames
    FROM sysibm.sysindexes si
    LEFT JOIN table_index_p tip ON rtrim(si.tbname) = tip.tbname
    WHERE lower(si.tbname) = '{table_lc}'
    AND lower(si.tbcreator) = '{schema_lc}'
    AND si.uniquerule = 'D'
    AND tip.tbname IS NULL
    )
    SELECT tbname AS TABLE_NAME, replace(replace(substr(colnames,2),'+',','),'-',',') AS COLUMN_NAME FROM table_index_p
    UNION
    SELECT tbname AS TABLE_NAME, replace(replace(substr(colnames,2),'+',','),'-',',') AS COLUMN_NAME FROM table_index_d ) AS primary_key_info
    """
    df = read_jdbc_data(jdbc_url, query, connection_props)
    # The result is tiny, so split it on the driver. This avoids the RDD API,
    # which is not supported on some Databricks cluster access modes.
    # dict.fromkeys de-duplicates while keeping first-seen column order.
    pk_list = list(
        dict.fromkeys(
            col.strip()
            for row in df.select("COLUMN_NAME").collect()
            if row[0]
            for col in row[0].split(",")
            if col.strip()
        )
    )
    table_name_src = f"{schema.lower()}.{table_name.lower()}"
    return {table_name_src: pk_list}


def check_table_exists(
    catalog: str, 
    schema: str, 
    table_name: str
    ) -> bool:
    """
    Report whether ``catalog.schema.table_name`` exists.

    First asks ``spark.catalog.tableExists``. If that call raises anything, it
    logs a warning and falls back to ``SHOW TABLES ... LIKE`` with an exact name
    match. Returns False if the fallback raises ``AnalysisException``.
    """
    full_table_name = f"{catalog}.{schema}.{table_name}"

    # Preferred path: the catalog API.
    try:
        exists = spark.catalog.tableExists(full_table_name)
    except Exception as e:
        logger.warning(f"tableExists method failed for {full_table_name}: {str(e)}. Trying alternative method.")
    else:
        return exists

    # Fallback path: SHOW TABLES, then filter to an exact name match.
    try:
        tables_df = spark.sql(f"SHOW TABLES IN {catalog}.{schema} LIKE '{table_name}'")
        matching = tables_df.filter(tables_df.tableName == table_name)
        return matching.count() > 0
    except AnalysisException as e:
        logger.error(f"Error checking existence of table {full_table_name}: {str(e)}")
        return False


def write_delta_table(
    df: DataFrame,
    mode: str,
    catalog: str,
    schema: str,
    table: str,
    partition_by: t.Optional[t.List[str]] = None,
    ):
    """
    Save ``df`` as the Delta table ``catalog.schema.table``.

    Args:
        df: Data to write.
        mode: Spark save mode (e.g. "overwrite").
        catalog: Target catalog.
        schema: Target schema.
        table: Target table name.
        partition_by: Optional partition columns; ignored when empty or None.
    """
    target = f"{catalog}.{schema}.{table}"
    # NOTE(review): "inferSchema" is a file-source read option and presumably has
    # no effect on a write — confirm before removing.
    writer = df.write.format("delta").mode(mode).option("inferSchema", "true")
    if partition_by:
        writer = writer.partitionBy(*partition_by)
    writer.saveAsTable(target)


def merge_delta_table(
    target_df: DeltaTable,
    source_df: DataFrame,
    merge_condition: str,
    update_columns: t.Dict[str, str],
    ):
    """
    Upsert ``source_df`` into the Delta table ``target_df``.

    ``merge_condition`` refers to the aliases ``target`` and ``source``. Matched
    rows are updated with ``update_columns``. Unmatched source rows are inserted
    with every source column.

    Returns:
        Whatever the merge builder's ``execute()`` returns.
    """
    insert_values = {name: f"source.{name}" for name in source_df.columns}
    builder = target_df.alias("target").merge(source_df.alias("source"), merge_condition)
    builder = builder.whenMatchedUpdate(set=update_columns)
    builder = builder.whenNotMatchedInsert(values=insert_values)
    return builder.execute()


def catalog_exists(catalog: str) -> bool:
    """
    Return True if ``catalog`` is listed by ``SHOW CATALOGS``.

    If the query raises ``AnalysisException``, the error is logged and False
    is returned.
    """
    try:
        catalogs = spark.sql("SHOW CATALOGS")
        hits = catalogs.filter(catalogs.catalog == catalog).count()
    except AnalysisException as e:
        logger.error(f"Error checking catalog {catalog}: {str(e)}")
        return False
    return hits > 0


def schema_exists(
    catalog: str, 
    schema: str
    ) -> bool:
    """
    Return True if ``SHOW SCHEMAS IN catalog`` lists ``schema`` in its
    ``databaseName`` column.

    If the query raises ``AnalysisException``, the error is logged and False
    is returned.
    """
    try:
        schemas = spark.sql(f"SHOW SCHEMAS IN {catalog}")
        hits = schemas.filter(schemas.databaseName == schema).count()
    except AnalysisException as e:
        logger.error(f"Error checking schema {schema} in catalog {catalog}: {str(e)}")
        return False
    return hits > 0


def check_and_set_catalog_schema(
    catalog: str, 
    schema: str
    ) -> None:
    """
    Verify that ``catalog`` and ``catalog.schema`` exist, then ``USE`` them.

    Raises:
        ValueError: If the catalog or the schema cannot be found.
    """
    if catalog_exists(catalog):
        logger.info(f"The catalog '{catalog}' exists.")
    else:
        raise ValueError(f"The catalog '{catalog}' does not exist.")

    if schema_exists(catalog, schema):
        logger.info(f"The schema '{schema}' exists in catalog '{catalog}'.")
    else:
        raise ValueError(f"The schema '{schema}' does not exist in catalog '{catalog}'.")

    qualified = f"{catalog}.{schema}"
    spark.sql(f"USE {qualified}")
    logger.info(f"Now using {qualified}")


def create_merge_condition(pk_columns: t.List[str]) -> str:
    """
    Build a MERGE condition that equates each key column between the
    ``target`` and ``source`` aliases, joined with " AND ".

    Example: ["id", "dt"] -> "target.id = source.id AND target.dt = source.dt".
    An empty list yields "".
    """
    clauses = (f"target.{name} = source.{name}" for name in pk_columns)
    return " AND ".join(clauses)


def create_update_columns(
    source_df: DataFrame, 
    pk_columns: t.List[str]
    ) -> t.Dict[str, str]:
    """
    Build the ``whenMatchedUpdate`` set-map: every non-key source column,
    as ``{"target.<col>": "source.<col>"}``, in source column order.
    """
    update_map: t.Dict[str, str] = {}
    for name in source_df.columns:
        if name in pk_columns:
            continue  # key columns are what we match on; never overwrite them
        update_map[f"target.{name}"] = f"source.{name}"
    return update_map


def _sanitize_column_name(name: str) -> str:
    """Strip spaces and parentheses, mirroring remove_spaces_from_column_headers."""
    return name.replace(" ", "").replace("(", "").replace(")", "")


def _resolve_pk_columns(
    jdbc_url: str,
    connection_props: t.Dict[str, str],
    src_schema: str,
    table_name: str,
    db_type: str,
    table_key_name: str,
    ) -> t.List[str]:
    """Return the merge key columns, sanitized the same way as the source headers."""
    if table_key_name:
        # A comma-separated value allows composite keys; a single name works as before.
        raw_names = [name.strip() for name in table_key_name.split(",")]
    else:
        # get_primary_keys maps a table identifier to its column list; only the
        # column names are needed for the merge condition.
        pk_lookup = get_primary_keys(
            jdbc_url, src_schema, table_name, connection_props, db_type
        )
        raw_names = [col for cols in pk_lookup.values() for col in cols]
    sanitized = (_sanitize_column_name(name) for name in raw_names)
    # De-duplicate while keeping order; drop names that sanitize to "".
    return list(dict.fromkeys(name for name in sanitized if name))


def _build_source_query(
    table_name_src: str,
    target_table_name: str,
    incremental_field: str,
    ) -> str:
    """Return the JDBC ``table`` argument: the source table, or an incremental subquery."""
    if not incremental_field:
        return table_name_src
    # history() lists commits newest first, so history(1) is the latest commit.
    # NOTE(review): maintenance commits (e.g. OPTIMIZE) also advance this
    # watermark, and the commit timestamp's time zone may differ from the
    # source's date values — confirm both are acceptable.
    last_commit = (
        DeltaTable.forName(spark, target_table_name)
        .history(1)
        .select("timestamp")
        .head(1)
    )
    if not last_commit:
        return table_name_src
    max_date = last_commit[0]["timestamp"].strftime("%Y-%m-%d")
    # ">=" because the watermark is truncated to a day. With ">", rows changed
    # later on the same day as the last load would never be picked up. Re-reading
    # that day is safe because the load is a keyed MERGE (upsert).
    # The alias is required: SQL Server rejects an un-aliased derived table.
    return (
        f"(SELECT * FROM {table_name_src} "
        f"WHERE CAST({incremental_field} AS DATE) >= CAST('{max_date}' AS DATE)) AS incremental_src"
    )


def process_table(
    jdbc_url: str,
    connection_props: t.Dict[str, str],
    dest_catalog: str,
    dest_schema: str,
    table_name: str,
    src_schema: str,
    db_type: str,
    overwrite: bool,
    table_key_name: str,
    incremental_field: str,
    ):
    """
    Load one source table into ``dest_catalog.dest_schema.<table_name lower-cased>``.

    Scenarios:
    1. Full load: when ``overwrite`` is True or the target does not exist, the
       whole source table is read and written with mode "overwrite".
    2. Incremental update: when ``incremental_field`` is set, only source rows
       whose ``incremental_field`` date is on or after the date of the target's
       latest Delta commit are read and MERGEd on the key columns.
    3. Full merge: otherwise the whole source table is MERGEd on the key columns.

    Key columns come from ``table_key_name``, which may be comma-separated for
    composite keys. When it is blank, they come from ``get_primary_keys``.

    Raises:
        ValueError: If no key columns can be determined, or a key column is
            missing from the source data.
    """
    logger.info(f"Processing table: {table_name}")
    table_name_src = f"{src_schema.lower()}.{table_name}"
    target_table = table_name.lower()
    target_table_name = f"{dest_catalog}.{dest_schema}.{target_table}"
    table_exists = check_table_exists(dest_catalog, dest_schema, target_table)

    if not table_exists or overwrite:
        source_df = remove_spaces_from_column_headers(
            read_jdbc_data(jdbc_url, table_name_src, connection_props)
        )
        write_delta_table(source_df, "overwrite", dest_catalog, dest_schema, target_table)
        logger.info(f"Table '{target_table_name}' has been created/overwritten.")
        return

    query = _build_source_query(table_name_src, target_table_name, incremental_field)
    source_df = remove_spaces_from_column_headers(
        read_jdbc_data(jdbc_url, query, connection_props)
    )
    # Fetching one row is enough to know whether there is anything to merge;
    # count() would pull every source row over JDBC first.
    if not source_df.head(1):
        logger.info(f"No new data to update in '{target_table_name}'.")
        return

    pk_columns = _resolve_pk_columns(
        jdbc_url, connection_props, src_schema, table_name, db_type, table_key_name
    )
    if not pk_columns:
        raise ValueError(
            f"No Primary Key found for {table_name}, please specify key."
        )
    # Case-insensitive check, so it fails with a clear message here rather than
    # with an unresolved-column error inside the MERGE.
    source_cols_lower = {col.lower() for col in source_df.columns}
    missing = [col for col in pk_columns if col.lower() not in source_cols_lower]
    if missing:
        raise ValueError(
            f"Key column(s) {missing} not found in source columns for {table_name}."
        )

    target_df = DeltaTable.forName(spark, target_table_name)
    merge_condition = create_merge_condition(pk_columns)
    update_columns = create_update_columns(source_df, pk_columns)
    merge_delta_table(target_df, source_df, merge_condition, update_columns)
    log_message = "incrementally updated" if incremental_field else "fully merged"
    logger.info(f"Table '{target_table_name}' has been {log_message}.")
    