In [0]:
# --- OneLake landing-zone connection settings ---------------------------------------
# Service-principal credentials and Fabric identifiers are read from the "KeyVault"
# secret scope, in this order: tenant, client id, client secret, WHID, WSID.
# NOTE(review): the tenant secret key is spelled "tenentId"; it must match the secret's
# actual name in Key Vault, so confirm before correcting the spelling.
ONELAKE_HOST = "onelake.dfs.fabric.microsoft.com"

tenant_id, client_id, client_secret, WHID, WSID = (
    dbutils.secrets.get(scope="KeyVault", key=secret_name)
    for secret_name in (
        "tenentId",
        "ADB-Mirror-SPN-ClientId",
        "ADB-Mirror-SPN-Secret",
        "ADB-Mirror-WHID",
        "ADB-Mirror-WSID",
    )
)

# Root of the landing zone: abfss://<WHID>@<ONELAKE_HOST>/<WSID>/Files/LandingZone
base_url = f"abfss://{WHID}@{ONELAKE_HOST}/{WSID}/Files/LandingZone"

# Fully qualified source table -> key column written to that table's _metadata.json.
primaryKeys = {
    "test_catalog.sales_data.orders": "order_id",
    # "test_catalog.sales_data.customers": "customer_id",
}

# Configure the ABFS driver to authenticate to OneLake with OAuth client credentials.
# Each key below is suffixed with ".<ONELAKE_HOST>" so it applies only to that account.
onelake_oauth_conf = {
    "fs.azure.account.auth.type": "OAuth",
    "fs.azure.account.oauth.provider.type": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider",
    "fs.azure.account.oauth2.client.id": client_id,
    "fs.azure.account.oauth2.client.secret": client_secret,
    "fs.azure.account.oauth2.client.endpoint": f"https://login.microsoftonline.com/{tenant_id}/oauth2/token",
}
for conf_key, conf_value in onelake_oauth_conf.items():
    spark.conf.set(f"{conf_key}.{ONELAKE_HOST}", conf_value)

In [0]:
def _landing_zone_url(full_table_path: str) -> str:
    """Map ``catalog.schema.table`` to ``<base_url>/<schema>.schema/<table>``.

    The catalog component is validated (it must be present) but is not part of the URL.

    Raises:
        ValueError: if ``full_table_path`` does not have exactly three dot-separated parts.
    """
    path_parts = full_table_path.split(".")
    if len(path_parts) != 3:
        raise ValueError(f"Invalid table path format: {full_table_path}")
    _catalog, schema_name, table_name = path_parts
    return f"{base_url}/{schema_name}.schema/{table_name}"


# Landing-zone directory for every configured source table, keyed by the full table path.
target_urls = {}
for full_table_path in primaryKeys:
    target_urls[full_table_path] = _landing_zone_url(full_table_path)



In [0]:
import json
from pyspark.sql import SparkSession
from py4j.protocol import Py4JJavaError
from pyspark.sql import functions as F
from pyspark.sql.functions import lit, when, col
from datetime import datetime

def get_last_processed_version(table_name: str):
    """Look up the last CDC version recorded for ``table_name`` in the control table.

    Args:
        table_name: Fully qualified source table name (``catalog.schema.table``),
            matched against the ``table_name`` column of
            ``test_catalog.config.cdc_control``.

    Returns:
        The stored ``last_processed_version`` when a control record exists, otherwise
        ``False`` (no record, or the control table could not be read), meaning
        "initial load required".

        A stored version can legitimately be ``0``, so callers should test for the
        ``False`` sentinel explicitly (``is False``) rather than relying on truthiness.
    """
    try:
        # Single Spark job: fetch at most one row instead of count() followed by collect().
        control_row = (
            spark.table("test_catalog.config.cdc_control")
            .filter(F.col("table_name") == table_name)
            .select("last_processed_version")
            .first()
        )
    except Exception as exc:
        # The control table itself might not exist on the first run. Any read failure is
        # treated as "no record" (→ initial load), but the reason is surfaced, not swallowed.
        print(f"Could not read CDC control entry for {table_name}: {exc}")
        return False

    if control_row is None:
        return False  # no record → initial load
    return control_row[0]


def write_metadata_file_safe(target_dir: str, table_name: str):
    """Write the landing-zone ``_metadata.json`` for ``table_name`` into ``target_dir``.

    The file is the JSON object ``{"KeyColumns": [...]}`` (4-space indented). The list holds
    the table's key column from ``primaryKeys``, or is empty when the table has no entry
    (or an empty one). Any existing ``_metadata.json`` is overwritten.
    """
    metadata_path = f"{target_dir}/_metadata.json"
    key_column = primaryKeys.get(table_name)
    key_columns = [key_column] if key_column else []
    dbutils.fs.put(metadata_path, json.dumps({"KeyColumns": key_columns}, indent=4), overwrite=True)
    print(f"Metadata file written to {metadata_path}")





In [0]:
def upsert_cdc_log(table_name: str, last_processed_version: int, last_processed_timestamp):
    """Record the latest processed CDC version and timestamp for ``table_name``.

    Runs a MERGE against ``test_catalog.config.cdc_control`` keyed on ``table_name``.
    A matching row gets its version, timestamp and ``updated_at`` overwritten; otherwise a
    new row is inserted. ``updated_at`` is always ``current_timestamp()``.

    Args:
        table_name: Fully qualified source table name, embedded as a SQL string literal.
        last_processed_version: Version to store, embedded verbatim in the SQL.
        last_processed_timestamp: Any value whose ``str()`` form ``TIMESTAMP(...)`` can
            parse (e.g. the ``timestamp`` column returned by ``DESCRIBE HISTORY``).
    """
    # NOTE(review): values are interpolated into the SQL text rather than bound as
    # parameters; a table name containing a single quote would break the statement.
    timestamp_text = str(last_processed_timestamp)
    merge_statement = f"""
        MERGE INTO test_catalog.config.cdc_control AS target
        USING (SELECT
                  '{table_name}' AS table_name,
                  {last_processed_version} AS last_processed_version,
                  TIMESTAMP('{timestamp_text}') AS last_processed_timestamp,
                  current_timestamp() AS updated_at
               ) AS source
        ON target.table_name = source.table_name
        WHEN MATCHED THEN
          UPDATE SET
            target.last_processed_version = source.last_processed_version,
            target.last_processed_timestamp = source.last_processed_timestamp,
            target.updated_at = source.updated_at
        WHEN NOT MATCHED THEN
          INSERT (table_name, last_processed_version, last_processed_timestamp, updated_at)
          VALUES (source.table_name, source.last_processed_version, source.last_processed_timestamp, source.updated_at)
    """
    spark.sql(merge_statement)
    print(f"CDC control updated → {table_name}: version={last_processed_version}")


In [0]:
def full_initial_load_single_parquet(target_dir: str, table_name: str):
    """Export a full snapshot of ``table_name`` as the first landing-zone parquet file.

    Steps:
      1. Write ``_metadata.json`` into ``target_dir`` via ``write_metadata_file_safe``.
      2. Pin the table's latest version from ``DESCRIBE HISTORY``.
      3. Read the table ``VERSION AS OF`` that version, write it with ``coalesce(1)`` to a
         DBFS scratch directory, and copy the resulting ``.parquet`` file to
         ``<target_dir>/00000000000000000001.parquet``.
      4. Record the pinned version and its timestamp via ``upsert_cdc_log``.

    Args:
        target_dir: Landing-zone directory for this table.
        table_name: Fully qualified source table name (``catalog.schema.table``).

    Raises:
        Exception: if no ``.parquet`` file was produced in the scratch directory
            (Spark/dbutils errors also propagate).
    """
    write_metadata_file_safe(target_dir, table_name)

    # Pin the version BEFORE reading. Reading the live table first and querying history
    # afterwards can record a version newer than the exported snapshot if a commit lands
    # in between; the next delta load would then start past that commit and lose it.
    latest = spark.sql(f"DESCRIBE HISTORY {table_name}").orderBy(F.desc("version")).first()
    snapshot_version = int(latest["version"])
    df = spark.sql(f"SELECT * FROM {table_name} VERSION AS OF {snapshot_version}")

    local_tmp_dir = "dbfs:/tmp/local_parquet"
    target_file = f"{target_dir}/00000000000000000001.parquet"
    dbutils.fs.rm(local_tmp_dir, recurse=True)
    try:
        # coalesce(1) → a single output partition, so Spark writes one part file.
        df.coalesce(1).write.mode("overwrite").parquet(local_tmp_dir)

        files = [f.path for f in dbutils.fs.ls(local_tmp_dir) if f.name.endswith(".parquet")]
        if not files:
            raise Exception("No parquet file generated in temp_dir")

        # Copy only the parquet file (skip _SUCCESS, logs)
        dbutils.fs.cp(files[0], target_file)
    finally:
        # Remove the scratch directory even when the export or copy fails.
        dbutils.fs.rm(local_tmp_dir, recurse=True)

    upsert_cdc_log(table_name, snapshot_version, latest["timestamp"])
    print(f"Full load complete: {target_file}")


In [0]:
import re

def get_next_parquet_number(target_dir: str) -> str:
    """Return the file name for the next sequentially numbered parquet file in ``target_dir``.

    Only entries whose whole name is ``<digits>.parquet`` are considered. Spark part files,
    ``_metadata.json`` and sub-directories are ignored. The result is ``max + 1``
    zero-padded to 20 digits. When no numbered file exists (including an empty
    directory), it is ``00000000000000000001.parquet``.

    Raises:
        Exception: whatever ``dbutils.fs.ls`` raises, after printing a diagnostic. The
            error is re-raised instead of returning ``None``, so callers never build a
            ``<target_dir>/None`` path.
    """
    try:
        entries = dbutils.fs.ls(target_dir)
    except Exception as e:
        print(f"Error while finding next parquet number for {target_dir}: {e}")
        raise

    numbers = [
        int(match.group(1))
        for match in (re.fullmatch(r"(\d+)\.parquet", entry.name) for entry in entries)
        if match
    ]
    next_num = max(numbers, default=0) + 1
    return f"{next_num:020d}.parquet"


In [0]:
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit, when

def run_delta_load(
    table_name: str,
    target_dir: str
):
    """
    Export changes committed to ``table_name`` since the last processed version.

    Reads the Change Data Feed with ``table_changes`` for versions
    ``last_processed_version + 1`` .. latest. Each change is mapped to a
    ``__rowMarker__`` (0 = insert, 1 = update post-image, 2 = delete; update pre-images
    are dropped). The rows are written as one parquet file with the next sequential name
    in ``target_dir``, and the new version is then recorded via ``upsert_cdc_log()``.

    Rows are emitted in commit order, so several changes to the same row within the
    exported range appear in the order they were committed.

    If the range contains no row-level changes, the control table is still advanced to
    the latest version so the same range is not rescanned on every run.

    Any exception is printed and swallowed (presumably so a loop over several tables
    can continue with the next one).
    """
    control_table = "test_catalog.config.cdc_control"
    local_tmp_dir = "dbfs:/tmp/local_parquet"
    try:
        # Single Spark job: at most one control row, instead of count() + collect().
        ctrl_row = (
            spark.read.table(control_table)
            .filter(col("table_name") == table_name)
            .select("last_processed_version")
            .first()
        )
        if ctrl_row is None:
            print(f"No control entry found for {table_name}. Please run initial load first.")
            return

        last_version = ctrl_row[0]
        print(f"Last processed version for {table_name}: {last_version}")

        # --- Get current latest version of the table ---
        latest = spark.sql(f"DESCRIBE HISTORY {table_name}").orderBy(F.desc("version")).first()
        new_version = latest['version']
        new_timestamp = latest['timestamp']

        # --- Validate versions before running CDC ---
        if last_version >= new_version:
            print(f"No new changes. Latest version ({new_version}) <= last processed ({last_version}).")
            return

        # table_changes is inclusive on both ends → start from last_version + 1.
        start_version = int(last_version) + 1
        end_version = int(new_version)

        query = f"SELECT * FROM table_changes('{table_name}', {start_version}, {end_version})"
        changes_df = spark.sql(query)

        # limit(1).count() as a cheap emptiness check (avoids .rdd, which the original
        # notes is unsupported under Unity Catalog).
        if changes_df.limit(1).count() == 0:
            print("No new changes found.")
            # Nothing to export for this range; advance the watermark so the same
            # versions are not re-read on the next run.
            upsert_cdc_log(table_name, new_version, new_timestamp)
            return

        marked_df = (
            changes_df
            .withColumn(
                "__rowMarker__",
                when(col("_change_type") == "insert", lit(0))
                .when(col("_change_type") == "update_postimage", lit(1))
                .when(col("_change_type") == "delete", lit(2))
                .otherwise(lit(None))
            )
            # update_preimage rows map to NULL and are dropped here.
            .filter(col("__rowMarker__").isNotNull())
            # Sort by commit version first so repeated changes to the same row keep their
            # real order. Sorting by marker alone (previous behaviour) would put e.g. a
            # re-insert before an earlier delete. Within one commit, deletes sort before
            # updates and inserts (2 → 1 → 0), since a single commit can delete and
            # re-insert the same key.
            # NOTE(review): this assumes the consumer applies rows in file order.
            .orderBy(col("_commit_version"), F.desc("__rowMarker__"))
            .drop("_change_type", "_commit_version", "_commit_timestamp")
        )

        # Resolve the target name before exporting; never write a file named "None".
        next_file = get_next_parquet_number(target_dir)
        if not next_file:
            raise Exception(f"Could not determine next parquet file name for {target_dir}")
        target_file = f"{target_dir}/{next_file}"

        dbutils.fs.rm(local_tmp_dir, recurse=True)
        try:
            marked_df.coalesce(1).write.mode("overwrite").parquet(local_tmp_dir)

            files = [f.path for f in dbutils.fs.ls(local_tmp_dir) if f.name.endswith(".parquet")]
            if not files:
                raise Exception("No parquet file generated in temp_dir")

            dbutils.fs.cp(files[0], target_file)
        finally:
            # Remove the scratch directory even when the export or copy fails.
            dbutils.fs.rm(local_tmp_dir, recurse=True)

        print(f"Delta parquet file created: {target_file}")

        upsert_cdc_log(table_name, new_version, new_timestamp)
        print(f"Delta load completed for {table_name}. Updated version → {new_version}")

    except Exception as e:
        print(f"Delta load failed for {table_name}: {e}")


In [0]:
# Initial load for tables with no CDC control record, incremental (CDF) load for the rest.
for full_table_name in primaryKeys:
    target_dir = target_urls[full_table_name]
    last_version = get_last_processed_version(full_table_name)
    # get_last_processed_version returns False when there is no control record. A stored
    # version may legitimately be 0, so test the sentinel explicitly instead of using
    # truthiness (otherwise a table loaded at version 0 would be fully reloaded every
    # run). A NULL stored version is also treated as "not loaded".
    if last_version is None or last_version is False:
        full_initial_load_single_parquet(target_dir, full_table_name)
    else:
        print(f"Skipping full load for {full_table_name} as it has already been processed")
        run_delta_load(full_table_name, target_dir)

