In [0]:
# Robust end-to-end: multi-column PK support (no hardcoding)
import json, traceback, datetime
from pyspark.sql.functions import col, trim, to_timestamp, year, lit, concat_ws, sha2, current_timestamp, coalesce
from delta.tables import DeltaTable

# ---------------------------
# Widgets (ADF will pass these)
# ---------------------------
# Notebook parameters as (widget name, default value). The widgets are created
# in exactly this order; every one of them is a free-text widget.
_WIDGET_DEFAULTS = (
    ("domain", ""),
    ("file_name", ""),             # e.g. Sales.Currency or Sales.Currency.csv
    ("column_list", ""),           # JSON array OR CSV (optional)
    ("pk_columns", ""),            # JSON array OR CSV, e.g. '["CurrencyAlternateKey"]' or 'col1,col2'
    ("year_column", ""),           # optional
    ("table_name", ""),            # optional override
    ("batch_name", ""),            # optional batch id/name from ADF
    ("direct_account_key", ""),    # optional storage account key (base64)
    ("Source_path", ""),           # optional full path (preferred)
    ("Target_path", ""),
    ("include_layer", "false"),    # optional (true/false)
    ("layer", "Bronze"),           # optional layer name
    ("merge_flag", "true"),
    ("incremental_flag", "false"),
)
for _widget_name, _widget_default in _WIDGET_DEFAULTS:
    dbutils.widgets.text(_widget_name, _widget_default)

# ---------------------------
# Read widget values
# ---------------------------
def _widget_text(name):
    """Return the current value of widget *name* with surrounding whitespace removed."""
    return dbutils.widgets.get(name).strip()


def _widget_flag(name):
    """Return True when widget *name* holds true/1/yes/y (case-insensitive), else False."""
    return _widget_text(name).lower() in ("true", "1", "yes", "y")


# Free-text parameters.
domain = _widget_text("domain")
file_name = _widget_text("file_name")
column_list_widget = _widget_text("column_list")
pk_columns_widget = _widget_text("pk_columns")
year_column = _widget_text("year_column")
table_name = _widget_text("table_name")
batch_name = _widget_text("batch_name")
direct_account_key = _widget_text("direct_account_key")
Source_path = _widget_text("Source_path")
Target_path = _widget_text("Target_path")
# Boolean switches arrive as text and are normalised to real booleans here.
include_layer = _widget_flag("include_layer")
layer = _widget_text("layer")
merge_flag = _widget_flag("merge_flag")
incremental_flag = _widget_flag("incremental_flag")

# Echo the effective parameters; direct_account_key is deliberately not printed.
print("Raw widget summary:")
print(" domain:", domain, " file_name:", file_name, " Source_path:", Source_path)
print(" Target_path:", Target_path, " include_layer:", include_layer, " layer:", layer)
print(" merge_flag:", merge_flag, " incremental:", incremental_flag)
print(" pk_columns(widget):", pk_columns_widget)

# ---------------------------
# Helper functions
# ---------------------------
def parse_column_list(txt):
    """Parse a widget value into a list of column names.

    Accepted forms:
      * JSON array, e.g. ``'["col1", "col2"]'``
      * JSON string, e.g. ``'"col1"'`` or ``'"col1,col2"'`` (split on commas)
      * plain comma-separated text, e.g. ``'col1, col2'``
      * a bare JSON scalar such as ``'2024'``, kept as the literal text

    Args:
        txt: Raw widget text; may be empty or None.

    Returns:
        List of non-empty, whitespace-trimmed names in input order. Empty
        input, JSON ``null`` and JSON objects yield ``[]``.
    """
    if not txt:
        return []
    txt = txt.strip()
    try:
        parsed = json.loads(txt)
    except (ValueError, RecursionError):
        # Not JSON at all: treat the raw text as CSV below.
        parsed = txt

    if isinstance(parsed, list):
        # JSON null items are dropped rather than turned into the name "None".
        candidates = ["" if item is None else str(item) for item in parsed]
    elif isinstance(parsed, str):
        candidates = parsed.split(",")
    elif parsed is None or isinstance(parsed, dict):
        return []
    else:
        # Bare number/boolean: use the text exactly as the caller wrote it.
        candidates = txt.split(",")
    return [name.strip() for name in candidates if name.strip()]

def is_absolute_path(p):
    """Return True when *p* is a string that starts with a known storage prefix.

    The comparison is case-insensitive. Recognised prefixes are the Azure
    schemes (abfss/abfs/wasbs/wasb/adl), the S3 schemes (s3/s3a) and the two
    DBFS spellings (``dbfs:/`` URI and ``/dbfs/`` local mount).

    Args:
        p: Candidate path; anything that is empty or not a ``str`` is rejected.

    Returns:
        bool: True for a recognised prefix, otherwise False.
    """
    if not p or not isinstance(p, str):
        return False
    known_prefixes = (
        "abfss://",
        "abfs://",
        "wasbs://",
        "wasb://",
        "adl://",
        "s3://",
        "s3a://",
        "dbfs:/",
        "/dbfs/",
    )
    # str.startswith accepts a tuple, so one call covers every prefix.
    return p.lower().startswith(known_prefixes)

# parse pk columns from widget (JSON array or CSV text -> list of names)
pk_columns = parse_column_list(pk_columns_widget)
print("Parsed pk_columns:", pk_columns)

# ---------------------------
# Determine storage account for spark.conf config (best-effort)
# ---------------------------
# Account used when neither path names one; change if your environment differs.
_FALLBACK_STORAGE_ACCOUNT = "scrgvkrmade"
_AZURE_ENDPOINTS = ("dfs.core.windows.net", "blob.core.windows.net")

# The account is taken as the text between "@" and the next "." in the first
# path (Source_path, then Target_path) that mentions an Azure storage endpoint.
storage_account = None
for p in (Source_path, Target_path):
    if not p or "@" not in p:
        continue
    if not any(endpoint in p for endpoint in _AZURE_ENDPOINTS):
        continue
    candidate = p.split("@", 1)[1].split(".", 1)[0].strip()
    if candidate:
        storage_account = candidate
        break
if not storage_account:
    storage_account = _FALLBACK_STORAGE_ACCOUNT
    print("Could not derive storage account from Source_path/Target_path; using fallback:", storage_account)

# configure storage keys if provided
# NOTE(review): the key arrives as a plain widget value; a secret scope may be a
# safer channel - confirm with the pipeline owner. The key itself is never printed.
if direct_account_key:
    k = direct_account_key.strip()
    # Drop one pair of matching quotes that may wrap the key.
    if (k.startswith('"') and k.endswith('"')) or (k.startswith("'") and k.endswith("'")):
        k = k[1:-1]
    k = k.strip()
    # Register the same key for both the Blob and the ADLS Gen2 (dfs) endpoints.
    for endpoint in ("blob.core.windows.net", "dfs.core.windows.net"):
        spark.conf.set(f"fs.azure.account.key.{storage_account}.{endpoint}", k)
    print("Set storage key for account:", storage_account)
else:
    # \u2014 is an em dash; written as an escape so the source file stays ASCII.
    print("No direct_account_key provided \u2014 ensure cluster MSI/secret/mount is configured.")

# ---------------------------
# Auto-build Source_path if missing (best-effort)
# ---------------------------
# A caller-supplied absolute path wins; otherwise one is assembled from the
# storage account, layer, domain and table/file name.
# NOTE(review): include_layer is read from its widget but is not consulted here;
# the layer segment is always part of the auto-built path - confirm intent.
if is_absolute_path(Source_path):
    print("Using provided Source_path ->", Source_path)
else:
    container = "project"
    acct = storage_account
    root_layer = layer or "Bronze"
    # table_name takes precedence over file_name as the folder name.
    target_folder = table_name or file_name
    if not target_folder:
        raise RuntimeError("Source_path not provided and neither table_name nor file_name are set. Provide Source_path or file_name/table_name.")
    target_folder = target_folder.strip("/ ")
    Source_path = "abfss://{}@{}.dfs.core.windows.net/{}/{}/{}/".format(
        container, acct, root_layer, domain, target_folder
    )
    print("Auto-built Source_path ->", Source_path)

# final validate
if not is_absolute_path(Source_path):
    raise RuntimeError(f"Final Source_path is not absolute: {Source_path}")

# ---------------------------
# READ SOURCE parquet (recursive)
# ---------------------------
# Load every parquet file found under Source_path (sub-folders included),
# merging the schemas of the individual files into one DataFrame.
try:
    print("Reading parquet data from ->", Source_path)
    _parquet_reader = spark.read
    _parquet_reader = _parquet_reader.option("mergeSchema", "true")
    _parquet_reader = _parquet_reader.option("recursiveFileLookup", "true")
    src_df = _parquet_reader.parquet(Source_path)
    print("Read OK. Columns:", src_df.columns)
    display(src_df.limit(5))
except Exception:
    # Print the full traceback, then fail the run with a message naming the path.
    traceback.print_exc()
    raise RuntimeError(f"Failed to read parquet from Source_path: {Source_path}")

# ---------------------------
# CLEAN / TRANSFORM using ALL available columns (generic)
# ---------------------------
df = src_df

# Run-level audit columns. They are left out of __row_hash: __ingest_ts and
# __batch_id change on every run, so hashing them would give the same record a
# different hash each time and a __row_hash-keyed merge could never match.
_AUDIT_COLUMNS = ("__ingest_ts", "__source_file", "__source_path", "__batch_id", "__row_hash")
# Captured before any derived column is added, so only source columns are hashed.
_hash_columns = [c for c in df.columns if c not in _AUDIT_COLUMNS]
# Placeholder hashed in place of NULL: concat_ws() skips NULL inputs, which would
# otherwise make ("a", NULL) and (NULL, "a") produce the same hash.
_NULL_TOKEN = "__NULL__"
# One timezone-aware UTC timestamp shared by the _year fallback and the generated
# batch id (datetime.utcnow() is deprecated since Python 3.12).
_run_started_utc = datetime.datetime.now(datetime.timezone.utc)

# trim strings
for c, d in df.dtypes:
    if d == "string":
        df = df.withColumn(c, trim(col(c)))

# year -> timestamp and _year if available
if year_column and year_column in df.columns:
    try:
        df = df.withColumn(year_column, to_timestamp(col(year_column)))
        df = df.withColumn("_year", year(col(year_column)))
    except Exception as exc:
        # Best-effort: keep going with the current year, but say why.
        print(f"Could not derive _year from column '{year_column}': {exc}")
        if "_year" not in df.columns:
            df = df.withColumn("_year", lit(_run_started_utc.year))
elif "_year" not in df.columns:
    # No usable year column: every row gets the current UTC year.
    df = df.withColumn("_year", lit(_run_started_utc.year))

# audit columns (an existing column of the same name is never overwritten)
if "__ingest_ts" not in df.columns:
    df = df.withColumn("__ingest_ts", current_timestamp())
if "__source_file" not in df.columns:
    # Last path segment, or the whole path when it ends with "/".
    df = df.withColumn("__source_file", lit(Source_path.split("/")[-1] or Source_path))
if "__source_path" not in df.columns:
    df = df.withColumn("__source_path", lit(Source_path))
if "__batch_id" not in df.columns:
    dbatch = batch_name if batch_name else "Batch-" + _run_started_utc.strftime("%Y%m%dT%H%M%S")
    df = df.withColumn("__batch_id", lit(dbatch))

# compute generic __row_hash over the source columns only (string-cast, NULL-safe)
concat_expr = concat_ws(
    "||",
    *[coalesce(col(c).cast("string"), lit(_NULL_TOKEN)) for c in _hash_columns]
)
df = df.withColumn("__row_hash", sha2(concat_expr, 256))

print("After generic cleaning. Columns now:", df.columns)
display(df.limit(5))

# ---------------------------
# DEDUPE (use provided multi-column PK if any are valid)
# ---------------------------
# determine which of the user-provided PK columns actually exist in df
# Split the requested PK columns into those present in df and those that are not.
# The match is an exact, case-sensitive comparison against df.columns.
_available_columns = set(df.columns)
valid_pk = [p for p in pk_columns if p in _available_columns]
missing_pk = [p for p in pk_columns if p not in _available_columns]
if missing_pk:
    # A partial composite key changes which rows count as duplicates, so make
    # the situation visible instead of silently ignoring the missing columns.
    _effective_keys = valid_pk if valid_pk else ["__row_hash"]
    print(f"WARNING: requested PK columns not found in data and ignored: {missing_pk}. "
          f"Dedupe will use: {_effective_keys}")

if valid_pk:
    # drop duplicates by all provided PK columns (multi-column composite key)
    try:
        before = df.count()
    except Exception as exc:
        # The count is only used for the log line; dedupe proceeds without it.
        print("Could not count rows before dedupe:", exc)
        before = None
    df = df.dropDuplicates(valid_pk)
    if before is not None:
        after = df.count()
        print(f"Dropped duplicates by PK columns {valid_pk}: {before-after} rows removed.")
    else:
        print(f"Dropped duplicates by PK columns {valid_pk}.")
else:
    # fallback: dedupe by __row_hash
    df = df.dropDuplicates(["__row_hash"])
    print("No valid PK columns found in data; deduped by __row_hash.")

display(df.limit(5))
print("Final DF columns (post-dedupe):", df.columns)

# ---------------------------
# WRITE / MERGE into Delta using the valid_pk composite key (if present)
# ---------------------------
print("Preparing to write/merge to target Delta:", Target_path)
if not Target_path:
    # Fail early with a clear message instead of an obscure reader/writer error.
    raise RuntimeError("Target_path is empty; provide the Delta table location to write to.")

# Ask Delta directly whether the path holds a table, rather than inferring
# "missing" from any exception raised by a trial read.
delta_exists = DeltaTable.isDeltaTable(spark, Target_path)

# NOTE(review): merge_flag and incremental_flag are read from their widgets but
# are not consulted here; an existing table is always merged - confirm intent.
if not delta_exists:
    print("Target Delta not found -> creating initial Delta.")
    (df.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .partitionBy("_year")
        .save(Target_path))
    print("Initial Delta created at:", Target_path)
else:
    # choose merge key(s): use valid_pk if present, else fallback to __row_hash
    merge_pks = valid_pk[:] if valid_pk else ["__row_hash"]
    # "<=>" is NULL-safe equality: with "=" a row whose key is NULL never matches
    # its earlier copy in the target and would be inserted again on every run.
    join_cond = " AND ".join(f"target.`{c}` <=> source.`{c}`" for c in merge_pks)
    print("Merging using keys:", merge_pks)
    dt = DeltaTable.forPath(spark, Target_path)
    (dt.alias("target")
        .merge(df.alias("source"), join_cond)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute())
    print("MERGE completed into:", Target_path)

# ---------------------------
# final sample / validation
# ---------------------------
# Read the target back to show a row count and a small sample. If that fails,
# fall back to listing the files under Target_path. Nothing here fails the run.
_target_read_error = None
try:
    tgt = spark.read.format("delta").load(Target_path)
    print("Target approx row count:", tgt.count())
    display(tgt.limit(5))
except Exception as e:
    _target_read_error = e

if _target_read_error is not None:
    print("Could not read target after write/merge:", _target_read_error)
    try:
        for _entry in dbutils.fs.ls(Target_path):
            print("-", _entry.path)
    except Exception as e2:
        print("Also failed to list target path:", e2)

print("End.")
