In [0]:
from pyspark.sql.functions import (
    col, lit, expr, to_timestamp,
    sum as spark_sum
)

In [0]:
# dbutils.widgets.text("TableName", "")
# dbutils.widgets.text("StageId", "")
# dbutils.widgets.text("ProcessInstanceId", "")
# dbutils.widgets.text("ProcessQueueId", "")

# TableName = dbutils.widgets.get("TableName")
# StageId = (dbutils.widgets.get("StageId"))
# ProcessInstanceId = (dbutils.widgets.get("ProcessInstanceId"))
# ProcessQueueId = (dbutils.widgets.get("ProcessQueueId"))


In [0]:
# Job parameters arrive through notebook widgets. The three identifiers are
# converted to int (a non-numeric value raises ValueError here); TableName is
# kept as the raw widget string.
StageId, ProcessInstanceId, ProcessQueueId = (
    int(dbutils.widgets.get(widget_name))
    for widget_name in ("StageId", "ProcessInstanceId", "ProcessQueueId")
)
TableName = dbutils.widgets.get("TableName")

In [0]:
# Mark the current table as InProgress in the control queue.
# This notebook only handles StageId 2; any other stage aborts the run before
# the queue row is touched.
if StageId != 2:
    raise Exception(f"Stage Id is not relavent to R2B-transformation for table: {TableName}")

# The values are bound as named parameters instead of being interpolated into
# the SQL text, so a TableName containing quotes cannot change the statement.
spark.sql(
    """
    UPDATE control.processqueue
    SET ProcessStatus = 'InProgress',
        ProcessStartTime = current_timestamp()
    WHERE StageId = :stage_id
        AND ProcessInstanceId = :process_instance_id
        AND ProcessQueueId = :process_queue_id
        AND TableName = :table_name
    """,
    args={
        "stage_id": StageId,
        "process_instance_id": ProcessInstanceId,
        "process_queue_id": ProcessQueueId,
        "table_name": TableName,
    },
)

In [0]:
# Load the raw CSV for this table. The first row is treated as the header and
# column types are inferred from the data.
raw_csv_path = f"/Volumes/workspace/raw/raw-volume/{TableName}/{TableName}.csv"
rawdf = spark.read.options(header=True, inferSchema=True).csv(raw_csv_path)

In [0]:
# Outcome flag read by the status-update cell that follows. It stays False
# unless the bronze write below completes.
status = False

bronze_path = "/Volumes/workspace/bronze/"


def normalize(c):
    """Return a column name lower-cased with spaces and underscores removed."""
    return c.replace(" ", "").replace("_", "").lower()


def quote_identifier(name):
    """Backtick-quote a column name so it can be embedded in a Spark SQL expression.

    Without quoting, names containing spaces, dots or reserved words fail to
    parse (or are misread as struct field access).
    """
    return "`" + name.replace("`", "``") + "`"


# -------------------------------------------------
# READ METADATA
# -------------------------------------------------
metadata_df = (
    spark.table("workspace.metadata.TableColumnDetails")
    .filter(col("TableName") == TableName)
)

# Collect the column metadata once; both the column list and the schema
# enforcement loop below reuse these rows.
metadata_rows = metadata_df.collect()
metadata_cols = [r.ColumnName for r in metadata_rows]

try:
    if not metadata_cols:
        # Without metadata every raw column would be dropped below, so fail
        # with an explicit message instead.
        raise Exception(f"No column metadata found for table: {TableName}")

    # -------------------------------------------------
    # COLUMN NAME NORMALIZATION & ALIGNMENT
    # -------------------------------------------------
    # Index metadata names by their normalized form; setdefault keeps the first
    # metadata column when two normalize to the same key.
    meta_by_normalized = {}
    for meta_col in metadata_cols:
        meta_by_normalized.setdefault(normalize(meta_col), meta_col)

    mapping = {
        raw_col: meta_by_normalized[normalize(raw_col)]
        for raw_col in rawdf.columns
        if normalize(raw_col) in meta_by_normalized
    }

    for raw_col, meta_col in mapping.items():
        rawdf = rawdf.withColumnRenamed(raw_col, meta_col)

    # Drop extra columns that have no metadata entry.
    known_cols = set(metadata_cols)
    rawdf = rawdf.select(
        [col(quote_identifier(c)) for c in rawdf.columns if c in known_cols]
    )

    # -------------------------------------------------
    # SCHEMA ENFORCEMENT + DATE HANDLING
    # -------------------------------------------------
    for row in metadata_rows:
        col_name = row.ColumnName
        spark_type = row.DataType.lower()
        quoted_name = quote_identifier(col_name)

        if col_name not in rawdf.columns:
            # Column declared in metadata but absent from the file: add it as
            # a typed NULL column.
            rawdf = rawdf.withColumn(col_name, lit(None).cast(spark_type))
        elif spark_type == "timestamp":
            # Timestamps are parsed as dd/MM/yyyy; values that do not match
            # become NULL rather than raising.
            # NOTE(review): a column that inferSchema already typed as
            # timestamp/date will presumably not match this pattern and turn
            # NULL - confirm the raw files always use dd/MM/yyyy.
            rawdf = rawdf.withColumn(
                col_name,
                expr(f"try_to_timestamp({quoted_name}, 'dd/MM/yyyy')")
            )
        else:
            # Safe cast for all other datatypes (invalid values become NULL).
            rawdf = rawdf.withColumn(
                col_name,
                expr(f"try_cast({quoted_name} AS {spark_type})")
            )

    # -------------------------------------------------
    # NULL VALIDATION (PK + NOT NULL)
    # -------------------------------------------------
    critical_cols = [
        r.ColumnName
        for r in metadata_df
            .filter((col("IsNullable") == 0) | (col("IsPrimaryKey") == "Y"))
            .select("ColumnName")
            .toLocalIterator()
    ]

    if critical_cols:
        null_exprs = [
            spark_sum(col(quote_identifier(c)).isNull().cast("int")).alias(c)
            for c in critical_cols
        ]

        null_counts = rawdf.select(null_exprs).collect()[0].asDict()
        # sum() over zero rows yields None, so test truthiness instead of
        # comparing with 0; an empty file therefore has no violations.
        violations = [c for c, cnt in null_counts.items() if cnt]

        if violations:
            raise Exception(
                f"Null values found in non-nullable columns: {violations}"
            )

    # -------------------------------------------------
    # DUPLICATE REMOVAL (PRIMARY KEY BASED)
    # -------------------------------------------------
    pk_cols = [
        r.ColumnName
        for r in metadata_df
            .filter(col("IsPrimaryKey") == "Y")
            .select("ColumnName")
            .toLocalIterator()
    ]

    if pk_cols:
        rawdf = rawdf.dropDuplicates(pk_cols)

    # -------------------------------------------------
    # WRITE TO BRONZE
    # -------------------------------------------------
    bronze_table = f"workspace.bronze.{TableName}"
    (
        rawdf.write
        .format("delta")
        .mode("append")   # or overwrite
        .saveAsTable(bronze_table)
    )
    status = True
    print(f"✅ Transformation completed for table: {TableName}")

except Exception as e:
    # The failure is reported here and recorded in `status`; the next cell
    # marks the queue row as Failed and raises.
    print(f"Error: {e}")
    status = False


In [0]:
# Record the final outcome of this run in the control queue.
# Both outcomes write the same columns, so a single UPDATE is issued and only
# the status text differs. ProcessDuration is the elapsed time since
# ProcessStartTime in whole minutes (seconds / 60 cast to BIGINT).
final_status = "Succeeded" if status else "Failed"

# Values are bound as named parameters instead of being interpolated into the
# SQL text, so a TableName containing quotes cannot change the statement.
spark.sql(
    """
    UPDATE control.processqueue
    SET
        ProcessStatus = :final_status,
        ProcessEndTime = current_timestamp(),
        ProcessDuration = CAST(
            (unix_timestamp(current_timestamp()) - unix_timestamp(ProcessStartTime)) / 60
            AS BIGINT
        )
    WHERE
        StageId = :stage_id
        AND ProcessInstanceId = :process_instance_id
        AND ProcessQueueId = :process_queue_id
        AND TableName = :table_name
    """,
    args={
        "final_status": final_status,
        "stage_id": StageId,
        "process_instance_id": ProcessInstanceId,
        "process_queue_id": ProcessQueueId,
        "table_name": TableName,
    },
)

if status:
    print(f"{TableName} Marked as Successful")
else:
    print(f"{TableName} Marked as Failed")
    # Re-raise so the failure reaches whatever invoked this notebook.
    raise Exception(f"Hard failure: {TableName} Failure detected")