In [0]:
# Remove all widgets currently defined on this notebook.
dbutils.widgets.removeAll()

In [0]:
# Declare the "table_name" text widget; "investor" is its default value.
dbutils.widgets.text("table_name","investor")

In [0]:
# Read the current value of the "table_name" widget; every later cell keys off this name.
table_name = dbutils.widgets.get("table_name")
# Echo the selected table so the run output shows which table was processed.
display(table_name)

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, trim, regexp_replace, when, to_date, lit

In [0]:
def load_bronze_table(spark, table_name):
    """Return the table ``workspace.bronze.<table_name>`` as loaded by ``spark.table``.

    Args:
        spark: Session object exposing ``table(name)``.
        table_name: Unqualified table name inside ``workspace.bronze``.

    Returns:
        Whatever ``spark.table`` returns for the fully qualified name.
    """
    bronze_identifier = f"workspace.bronze.{table_name}"
    return spark.table(bronze_identifier)

In [0]:
def check_not_null(df, pk_cols):
    """Fail when any of the given columns contains a NULL value.

    This check is applied both to primary-key columns and to columns carrying a
    NOT NULL constraint, so the error message refers to a generic
    "non-nullable column" rather than to a primary key.

    Args:
        df: DataFrame to validate.
        pk_cols: Iterable of column names that must not contain NULLs.

    Raises:
        ValueError: On the first column (in iteration order) that holds at least
            one NULL. ``ValueError`` matches ``check_no_duplicates`` and is still
            caught by callers that handle ``Exception``.
    """
    for column_name in pk_cols:
        # Only existence matters, so limit(1) lets Spark stop at the first
        # offending row instead of counting every NULL in the column.
        has_null = df.filter(col(column_name).isNull()).limit(1).count() > 0
        if has_null:
            raise ValueError(f"Null value found in non-nullable column: {column_name}")

In [0]:
def check_no_duplicates(df, pk_cols):
    """Ensure that no combination of ``pk_cols`` values appears more than once.

    Args:
        df: DataFrame to validate.
        pk_cols: Column names that together form the key.

    Raises:
        ValueError: If at least one key combination occurs in more than one row.
    """
    rows_per_key = df.groupBy(pk_cols).count()
    repeated_keys = rows_per_key.filter(col("count") > 1)
    if repeated_keys.count() > 0:
        raise ValueError("Duplicate primary key values found.")

In [0]:
def clean_string_columns(df, string_cols):
    """Trim each listed column and strip the characters ``$``, ``#`` and ``%`` from it.

    Args:
        df: Input DataFrame.
        string_cols: Names of the columns to clean.

    Returns:
        A DataFrame in which every listed column has been replaced by its
        cleaned version; other columns are untouched.
    """
    cleaned = df
    for name in string_cols:
        trimmed = trim(col(name))
        # Character class: remove every '$', '#' or '%' anywhere in the value.
        without_symbols = regexp_replace(trimmed, "[$#%]", "")
        cleaned = cleaned.withColumn(name, without_symbols)
    return cleaned

In [0]:
def clean_numeric_columns(df, numeric_cols):
    """Replace NULLs in the listed columns with the literal ``0``.

    Args:
        df: Input DataFrame.
        numeric_cols: Names of the columns whose NULLs should become 0.

    Returns:
        A DataFrame with the listed columns NULL-filled; other columns are untouched.
    """
    filled = df
    for name in numeric_cols:
        current = col(name)
        zero_when_missing = when(current.isNull(), lit(0)).otherwise(current)
        filled = filled.withColumn(name, zero_when_missing)
    return filled

In [0]:
def clean_date_columns(df, date_cols):
    """Convert the listed columns to dates, mapping NULL to ``9999-12-31``.

    Non-NULL values are passed through ``to_date`` with the ``yyyy-MM-dd``
    pattern; NULL values are replaced by the sentinel date 9999-12-31.

    Args:
        df: Input DataFrame.
        date_cols: Names of the columns to convert.

    Returns:
        A DataFrame with the listed columns converted; other columns are untouched.
    """
    converted = df
    for name in date_cols:
        raw_value = col(name)
        sentinel = to_date(lit("9999-12-31"), "yyyy-MM-dd")
        parsed = to_date(raw_value, "yyyy-MM-dd")
        converted = converted.withColumn(
            name,
            when(raw_value.isNull(), sentinel).otherwise(parsed),
        )
    return converted

In [0]:
def fill_na_with_default(df, na_fields):
    """Replace NULLs in each configured column with its typed default value.

    Args:
        df: Input DataFrame.
        na_fields: Iterable of ``(field, default_value, data_type)`` tuples.
            ``data_type`` is matched case-insensitively to decide how the default
            literal is built:

            * int / integer / bigint / smallint -> ``int(default_value)`` cast to ``data_type``
            * float / double / decimal          -> ``float(default_value)`` cast to ``data_type``
            * date                              -> ``to_date(default_value, "yyyy-MM-dd")``
            * timestamp                         -> ``to_timestamp(default_value)``
            * anything else                     -> the raw default cast to ``data_type``

    Returns:
        A DataFrame in which NULLs of every listed column are replaced by the
        default; non-NULL values are kept as they are.

    Raises:
        ValueError: If a numeric default cannot be converted by ``int``/``float``.
    """
    # to_timestamp is not part of the notebook's import cell; importing it here
    # keeps the "timestamp" branch from failing with NameError.
    from pyspark.sql.functions import to_timestamp

    for field, default_value, data_type in na_fields:
        type_key = data_type.lower()  # normalise once for all comparisons
        if type_key in ["int", "integer", "bigint", "smallint"]:
            casted_default = lit(int(default_value)).cast(data_type)
        elif type_key in ["float", "double", "decimal"]:
            casted_default = lit(float(default_value)).cast(data_type)
        elif type_key in ["date"]:
            casted_default = to_date(lit(default_value), "yyyy-MM-dd")
        elif type_key in ["timestamp"]:
            casted_default = to_timestamp(lit(default_value))
        else:
            casted_default = lit(default_value).cast(data_type)
        df = df.withColumn(
            field,
            when(col(field).isNull(), casted_default).otherwise(col(field))
        )
    return df

In [0]:
def rename_columns_from_constraints(spark, table_name, schema="workspace.silver"):
    """Rename columns of ``<schema>.<table_name>`` as configured in the control table.

    Rename pairs are read from ``workspace.control.enrich_table_constraints``:
    active rows for ``table_name`` whose ``field_rename`` is neither NULL nor
    the string ``"na"``. A pair is applied only when the old and new names
    differ, the old name currently exists on the table, and the new name does
    not. Existence is judged against the column list read once before any
    rename is issued.

    Args:
        spark: Active Spark session.
        table_name: Unqualified name of the table to alter.
        schema: Catalog and schema that contain the table.
    """
    target_table = f"{schema}.{table_name}"

    wants_rename = (
        (col("table_name") == table_name) &
        (col("is_active") == True) &
        (col("field_rename").isNotNull()) &
        (col("field_rename") != "na")
    )
    mapping_rows = (
        spark.table("workspace.control.enrich_table_constraints")
        .filter(wants_rename)
        .select("field_name", "field_rename")
        .distinct()
        .collect()
    )

    # Snapshot of the table's column names taken before any ALTER runs.
    current_names = {column.name for column in spark.catalog.listColumns(target_table)}

    for mapping in mapping_rows:
        source_name, target_name = mapping.field_name, mapping.field_rename
        if source_name == target_name:
            continue
        if source_name not in current_names or target_name in current_names:
            continue
        spark.sql(f'ALTER TABLE {target_table} RENAME COLUMN {source_name} TO {target_name}')

In [0]:
# def incremental_upsert(spark, df, table_name, pk_fields):
#     silver_table = f"workspace.silver.{table_name}"
#     table_exists = spark.catalog.tableExists(silver_table)
#     if not table_exists:
#         df.write.format("delta").mode("overwrite").saveAsTable(silver_table)
#     else:
#         df.createOrReplaceTempView("staging_view")
#         merge_condition = " AND ".join([f"t.{pk}=s.{pk}" for pk in pk_fields])
#         set_clause = ", ".join([f"t.{c}=s.{c}" for c in df.columns])
#         insert_cols = ", ".join(df.columns)
#         insert_vals = ", ".join([f"s.{c}" for c in df.columns])
#         merge_sql = f"""
#             MERGE INTO {silver_table} AS t
#             USING staging_view AS s
#             ON {merge_condition}
#             WHEN MATCHED THEN UPDATE SET {set_clause}
#             WHEN NOT MATCHED THEN INSERT ({insert_cols}) VALUES ({insert_vals})
#         """
#         spark.sql(merge_sql)

In [0]:
def incremental_upsert(spark, df, table_name, pk_fields):
    """Create ``workspace.silver.<table_name>`` or merge ``df`` into it.

    * If the table does not exist it is created from ``df`` (Delta, overwrite).
      ``created_date`` (current timestamp) and ``created_by`` are added when
      ``df`` does not already carry them.
    * If the table exists, ``df`` is registered as the temp view
      ``staging_view`` and merged on ``pk_fields``: matching rows are updated,
      new rows are inserted.

    Audit columns on merge: when the target table has ``created_date`` /
    ``created_by`` and ``df`` does not, they are stamped onto the staged rows so
    newly inserted rows are populated just like on the initial load. Columns
    stamped this way are left out of the UPDATE clause, so rows that already
    exist keep their original creation metadata.

    Args:
        spark: Active Spark session.
        df: DataFrame holding the rows to write.
        table_name: Unqualified name of the silver table.
        pk_fields: Column names forming the merge key.

    Raises:
        ValueError: If the table already exists and ``pk_fields`` is empty; a
            MERGE cannot be built without a join condition.
    """
    from pyspark.sql.functions import lit, current_timestamp

    silver_table = f"workspace.silver.{table_name}"

    def _audit_defaults():
        # Built lazily so the column expressions are created only when needed.
        return {
            "created_date": current_timestamp(),
            "created_by": lit("anoopdk"),
        }

    if not spark.catalog.tableExists(silver_table):
        for audit_col, default in _audit_defaults().items():
            if audit_col not in df.columns:
                df = df.withColumn(audit_col, default)
        df.write.format("delta").mode("overwrite").saveAsTable(silver_table)
        return

    if not pk_fields:
        # An empty key would render "ON " with no condition, i.e. invalid SQL.
        raise ValueError(
            f"Cannot merge into existing table {silver_table}: no primary key fields were supplied."
        )

    # Stamp audit columns only when the target has them and the source does not.
    target_cols = set(spark.table(silver_table).columns)
    stamped_cols = set()
    for audit_col, default in _audit_defaults().items():
        if audit_col not in df.columns and audit_col in target_cols:
            df = df.withColumn(audit_col, default)
            stamped_cols.add(audit_col)

    df.createOrReplaceTempView("staging_view")
    merge_condition = " AND ".join([f"t.{pk}=s.{pk}" for pk in pk_fields])
    # Stamped audit columns describe row creation; never overwrite them on update.
    set_clause = ", ".join([f"t.{c}=s.{c}" for c in df.columns if c not in stamped_cols])
    insert_cols = ", ".join(df.columns)
    insert_vals = ", ".join([f"s.{c}" for c in df.columns])
    merge_sql = f"""
            MERGE INTO {silver_table} AS t
            USING staging_view AS s
            ON {merge_condition}
            WHEN MATCHED THEN UPDATE SET {set_clause}
            WHEN NOT MATCHED THEN INSERT ({insert_cols}) VALUES ({insert_vals})
        """
    spark.sql(merge_sql)

In [0]:
# def incremental_upsert(spark, df, table_name, pk_fields):
#     from pyspark.sql.functions import lit, current_timestamp

#     silver_table = f"workspace.silver.{table_name}"
#     table_exists = spark.catalog.tableExists(silver_table)

#     # Add metadata columns if not present
#     metadata_cols = ["created_date", "created_by", "updated_date", "updated_by"]
#     for col_name in metadata_cols:
#         if col_name not in df.columns:
#             df = df.withColumn(col_name, lit(None).cast("timestamp" if "date" in col_name else "string"))

#     if not table_exists:
#         # For initial ingestion, set created_date/current_timestamp, created_by/'anoopdk', updated_date/None, updated_by/None
#         df = (
#             df.withColumn("created_date", current_timestamp())
#               .withColumn("created_by", lit("anoopdk"))
#               .withColumn("updated_date", lit(None).cast("timestamp"))
#               .withColumn("updated_by", lit(None).cast("string"))
#         )
#         df.write.format("delta").mode("overwrite").saveAsTable(silver_table)
#     else:
#         df.createOrReplaceTempView("staging_view")
#         merge_condition = " AND ".join([f"t.{pk}=s.{pk}" for pk in pk_fields])

#         # Columns for update and insert
#         update_set = []
#         for c in df.columns:
#             if c == "updated_date":
#                 update_set.append(f"t.updated_date = current_timestamp()")
#             elif c == "updated_by":
#                 update_set.append(f"t.updated_by = 'anoopdk'")
#             elif c not in ["created_date", "created_by"]:
#                 update_set.append(f"t.{c} = s.{c}")
#         set_clause = ", ".join(update_set)

#         insert_cols = ", ".join(df.columns)
#         insert_vals = []
#         for c in df.columns:
#             if c == "created_date":
#                 insert_vals.append("current_timestamp()")
#             elif c == "created_by":
#                 insert_vals.append("'anoopdk'")
#             elif c in ["updated_date", "updated_by"]:
#                 insert_vals.append("NULL")
#             else:
#                 insert_vals.append(f"s.{c}")
#         insert_vals_str = ", ".join(insert_vals)

#         merge_sql = f"""
#             MERGE INTO {silver_table} t
#             USING staging_view s
#             ON {merge_condition}
#             WHEN MATCHED THEN UPDATE SET {set_clause}
#             WHEN NOT MATCHED THEN INSERT ({insert_cols}) VALUES ({insert_vals_str})
#         """
#         spark.sql(merge_sql)

In [0]:
def process_table(spark, table_name, pk_fields, not_null_fields, na_fields, string_cols, numeric_cols, date_cols):
    """Validate, clean and upsert one bronze table into the silver layer.

    Steps, in order:
      1. Load the bronze table.
      2. If ``pk_fields`` is non-empty, check those columns for NULLs and duplicates.
      3. If ``not_null_fields`` is non-empty, check those columns for NULLs.
      4. Apply, each only when its column list is non-empty: default filling,
         string cleaning, numeric cleaning, date cleaning.
      5. Upsert the result via ``incremental_upsert``.

    Args:
        spark: Active Spark session.
        table_name: Unqualified table name shared by the bronze and silver layers.
        pk_fields: Primary-key column names.
        not_null_fields: Column names that must not contain NULLs.
        na_fields: ``(field, default_value, data_type)`` tuples for default filling.
        string_cols: Column names to clean as strings.
        numeric_cols: Column names to clean as numerics.
        date_cols: Column names to clean as dates.
    """
    df = load_bronze_table(spark, table_name)

    # Validation stage: raises before anything is written.
    if pk_fields:
        check_not_null(df, pk_fields)
        check_no_duplicates(df, pk_fields)
    if not_null_fields:
        check_not_null(df, not_null_fields)

    # Cleaning stage: each transform runs only when it has columns to work on.
    cleaning_steps = (
        (na_fields, fill_na_with_default),
        (string_cols, clean_string_columns),
        (numeric_cols, clean_numeric_columns),
        (date_cols, clean_date_columns),
    )
    for targets, transform in cleaning_steps:
        if targets:
            df = transform(df, targets)

    incremental_upsert(spark, df, table_name, pk_fields)
    # Column renaming is currently disabled:
    # rename_columns_from_constraints(spark, table_name)

In [0]:
from pyspark.sql.functions import col

# Active constraint rows registered for the selected table.
constraints_df = spark.table("workspace.control.enrich_table_constraints").filter(
    (col("table_name") == table_name) & (col("is_active") == True)
)


def _field_names(frame, predicate):
    """Collect the ``field_name`` of every row in ``frame`` that satisfies ``predicate``."""
    return [row.field_name for row in frame.filter(predicate).collect()]


# Constraint-driven column lists.
pk_fields = _field_names(constraints_df, col("field_constraint") == "primary key")
not_null_fields = _field_names(constraints_df, col("field_constraint") == "not null")

# "na" constraints that carry a default: (field, default value, declared data type).
na_rows = constraints_df.filter(
    (col("field_constraint") == "na") & (col("default_value").isNotNull())
).collect()
na_fields = [
    (row.field_name, row.default_value, row.field_data_type)
    for row in na_rows
]

# Field specifications of the source file whose file_name equals the table name.
source_details = spark.table("workspace.control.source_file_details").filter(
    col("file_name") == table_name
)
source_specs = spark.table("workspace.control.source_file_specifications")
fields_df = source_details.join(source_specs, "source_id").select("field_name", "data_type")

# Bucket the fields by declared data type (case-insensitive regex match).
string_cols = _field_names(fields_df, col("data_type").rlike("(?i)string|char|text"))
numeric_cols = _field_names(fields_df, col("data_type").rlike("(?i)int|float|double|decimal|number"))
date_cols = _field_names(fields_df, col("data_type").rlike("(?i)date|timestamp"))

In [0]:
# Run the pipeline for the widget-selected table, using the constraint and
# column lists derived in the previous cell.
process_table(
    spark,
    table_name=table_name,
    pk_fields=pk_fields,
    not_null_fields=not_null_fields,
    na_fields=na_fields,
    string_cols=string_cols,
    numeric_cols=numeric_cols,
    date_cols=date_cols
)