# In [None]:
# Bronze→Silver→Gold transform, incremental load with watermark

# ================== nb_fmcg_medallion ==================
from pyspark.sql import functions as F, types as T
from notebookutils import mssparkutils
import json
from datetime import datetime

spark.conf.set("spark.sql.shuffle.partitions", "64")

# --------- 0) Environment selection ---------
# Resolve the target environment from the "env" job tag. Fall back to "dev"
# when the tag is empty/missing or the lookup itself fails.
# NOTE(review): presumably the lookup fails for interactive (non-job) runs;
# confirm against mssparkutils docs.
try:
    env = mssparkutils.env.getJobTag("env") or "dev"
except Exception as err:  # narrow from bare except: don't trap KeyboardInterrupt/SystemExit
    print(f"Could not read 'env' job tag ({err!r}); defaulting to dev")
    env = "dev"
print(f"Environment: {env}")

# --------- 1) Load env config ---------
# Per-environment JSON config, read as text from the lakehouse Files area.
# head() returns at most _MAX_CFG_BYTES bytes, so a larger file would be cut
# off and fail in json.loads.
_MAX_CFG_BYTES = 5_000_000
cfg_path = f"Files/config/{env}.json"
cfg_str = mssparkutils.fs.head(cfg_path, _MAX_CFG_BYTES)
cfg = json.loads(cfg_str)

# Required settings (KeyError if missing).
SRC_GLOB = cfg["source_glob"]   # input path/glob for the CSV reader
TBL = cfg["tables"]             # logical name -> table name mapping

# Optional settings with defaults.
DQ_OPTS = cfg.get("dq", {})
JDBC = cfg.get("jdbc", {"enabled": False})

print("Source:", SRC_GLOB)

# --------- 2) BRONZE (with the validated type/format transforms) ---------
# Read the source CSVs. No schema inference is requested, so every column is
# read as a string and is typed explicitly below.
bronze_df = (
    spark.read.format("csv")
    .option("header", True)
    .option("multiLine", True)
    .option("escape", '"')
    .load(SRC_GLOB)
    # Values that don't match the pattern are not parsed as dates.
    # NOTE(review): they become NULL unless ANSI mode is on — confirm runtime setting.
    .withColumn("date", F.to_date(F.col("date"), "M/d/yyyy"))
    .withColumn("price_unit", F.col("price_unit").cast(T.DoubleType()))
    .withColumn("promotion_flag", F.col("promotion_flag").cast(T.BooleanType()))
    .withColumn("region", F.initcap(F.col("region")))
)

# Persist the snapshot, replacing it on every run. Despite the "bronze_raw"
# table key, the rows are stored AFTER the transforms above, not verbatim
# from the CSV.
bronze_df.write.format("delta").mode("overwrite").saveAsTable(TBL["bronze_raw"])
print(f"Bronze written: {TBL['bronze_raw']}")

# --------- 3) SILVER (incremental + standardization) ---------
# Additional standardization: dedupe, strict types for other numeric columns
# Deduplicate per (date, sku, channel) and type the remaining numeric columns.
# dropDuplicates keeps an arbitrary row for each duplicated key.
incoming = (bronze_df
    .dropDuplicates(["date", "sku", "channel"])
    .withColumn("units_sold",      F.col("units_sold").cast(T.LongType()))
    .withColumn("delivered_qty",   F.col("delivered_qty").cast(T.LongType()))
    .withColumn("delivery_days",   F.col("delivery_days").cast(T.DoubleType()))
    .withColumn("stock_available", F.col("stock_available").cast(T.LongType()))
)

# Watermark read: highest `date` previously loaded into Silver (None on first run).
wm_df = spark.table(TBL["meta_watermark"]).filter(F.col("table_name") == TBL["silver_clean"])
# head() returns None for an empty result, so no separate count() action is needed.
wm_row = wm_df.select("max_date").head()
maxdt = wm_row["max_date"] if wm_row is not None else None

# Only rows strictly newer than the watermark are loaded. Once a watermark
# exists, NULL-date rows and late rows dated <= watermark never pass this filter.
df_new = incoming if maxdt is None else incoming.filter(F.col("date") > F.lit(maxdt))

# Row count and date range in ONE aggregation. Previously this was count() plus
# a later agg(), and each action re-scanned the CSV source.
stats = df_new.agg(
    F.count(F.lit(1)).alias("n_rows"),
    F.min("date").alias("min_d"),
    F.max("date").alias("max_d"),
).collect()[0]
rows = stats["n_rows"]
min_d, max_d = stats["min_d"], stats["max_d"]

if rows > 0:
    # Append to Silver.
    # NOTE(review): the append and the watermark update below are not atomic.
    # A failure in between leaves the old watermark, and a rerun re-appends
    # the same rows.
    df_new.write.mode("append").format("delta").saveAsTable(TBL["silver_clean"])

    # Replace this table's watermark row. The name comes from config, so embedded
    # single quotes are escaped before interpolation into the SQL literal.
    silver_name_sql = TBL["silver_clean"].replace("'", "''")
    spark.sql(f"DELETE FROM {TBL['meta_watermark']} WHERE table_name='{silver_name_sql}'")
    # Typed DataFrames avoid string -> DATE parsing issues.
    spark.createDataFrame([(TBL["silver_clean"], max_d)],
                          "table_name string, max_date date"
    ).write.mode("append").saveAsTable(TBL["meta_watermark"])

    # Audit row. NOTE(review): rows_in and rows_out are both the appended count.
    spark.createDataFrame(
        [(datetime.now(), SRC_GLOB, rows, rows, min_d, max_d)],
        "run_ts timestamp, source_glob string, rows_in long, rows_out long, min_date date, max_date date"
    ).write.mode("append").saveAsTable(TBL["meta_audit"])

print(f"Silver incremental done | rows appended: {rows}")

# --------- 4) DQ (optional but recommended) ---------
# Which check families are enabled (both default to on).
_dq_check_negative = DQ_OPTS.get("enforce_negative_checks", True)
_dq_check_null_key = DQ_OPTS.get("enforce_null_key_checks", True)


def _dq_flag(condition, enabled=True):
    """Return a 0/1 int column for *condition*, or constant 0 if the check is disabled.

    A comparison against NULL yields NULL. Coalescing to 0 stops one NULL measure
    from turning the summed error_flag into NULL, which would hide a row that
    has a real error from the `error_flag > 0` filter.
    """
    if not enabled:
        return F.lit(0)
    return F.coalesce(condition.cast("int"), F.lit(0))


if _dq_check_negative or _dq_check_null_key:
    s = spark.table(TBL["silver_clean"])
    dq = (s.select("date","sku","region","channel","units_sold","delivered_qty","price_unit","stock_available","promotion_flag")
      .withColumn("err_negative_units", _dq_flag(F.col("units_sold") < 0, _dq_check_negative))
      .withColumn("err_negative_delivered", _dq_flag(F.col("delivered_qty") < 0, _dq_check_negative))
      .withColumn("err_null_key", _dq_flag(
            F.col("sku").isNull() | F.col("region").isNull() | F.col("price_unit").isNull(),
            _dq_check_null_key))
      # Not tied to either config flag; runs whenever the DQ step runs.
      .withColumn("err_delivered_lt_sold", _dq_flag(F.col("delivered_qty") < F.col("units_sold")))
      .withColumn("error_flag",
            F.col("err_negative_units")+F.col("err_negative_delivered")+F.col("err_null_key")+F.col("err_delivered_lt_sold"))
      .filter("error_flag > 0")
    )
    # Full rebuild over the whole Silver table on every run.
    dq.write.mode("overwrite").format("delta").saveAsTable(TBL["dq_errors"])
    print(f"DQ table written: {TBL['dq_errors']}")

# --------- 5) GOLD (promo, aggregates, BI fact) ---------
# Promo table: promotion_flag is BOOLEAN now -> compare to True
# Promo rule parameters (values unchanged from the original inline literals).
PROMO_MIN_PRICE = 8.0   # keep rows whose price_unit is strictly above this
PROMO_DISCOUNT = 0.9    # effective_price = price_unit * PROMO_DISCOUNT

_is_promo = F.col("promotion_flag") == F.lit(True)
_above_floor = F.col("price_unit") > F.lit(PROMO_MIN_PRICE)

promo_df = (
    spark.table(TBL["silver_clean"])
    .where(_is_promo & _above_floor)
    .withColumn("effective_price", F.col("price_unit") * F.lit(PROMO_DISCOUNT))
)
promo_df.write.format("delta").mode("overwrite").saveAsTable(TBL["silver_promo"])

# Region × SKU aggregation
# Per (region, sku) totals over the full Silver table.
# stock_utilization is NULL when total stock is not positive. This matches the
# sell_through_rate guard in the category table and avoids a divide-by-zero
# error when spark.sql.ansi.enabled is on.
agg_df = (spark.table(TBL["silver_clean"])
  .groupBy("region","sku")
  .agg(
    F.sum("units_sold").alias("total_units_sold"),
    F.sum("delivered_qty").alias("total_delivered_qty"),
    F.avg("delivery_days").alias("avg_delivery_days"),
    F.when(F.sum("stock_available") > 0,
           F.sum("units_sold") / F.sum("stock_available")).alias("stock_utilization")
  )
)
agg_df.write.mode("overwrite").format("delta").saveAsTable(TBL["gold_region_sku"])

# Category contribution
# Units sold vs. stock per (segment, category, pack_type).
# sell_through_rate is left NULL when total stock is not positive; a `when`
# with no `otherwise` already produces NULL in that case.
_cat_keys = ["segment", "category", "pack_type"]
_has_stock = F.col("total_stock") > 0

cat_df = (
    spark.table(TBL["silver_clean"])
    .groupBy(*_cat_keys)
    .agg(
        F.sum("units_sold").alias("total_sales_units"),
        F.sum("stock_available").alias("total_stock"),
    )
    .withColumn(
        "sell_through_rate",
        F.when(_has_stock, F.col("total_sales_units") / F.col("total_stock")),
    )
)
cat_df.write.format("delta").mode("overwrite").saveAsTable(TBL["gold_category"])

# Final BI fact (Direct Lake)
# Final BI fact (Direct Lake): daily revenue per region/brand/sku.
# Promo rows supply a discounted price; other rows keep price_unit.
sales = spark.table(TBL["silver_clean"])
promo = spark.table(TBL["silver_promo"])

# At most one promo price per join key. Duplicate keys would multiply the
# matching sales rows in the left join and overstate units and revenue.
promo_prices = (promo
  .select("date","sku","channel","effective_price")
  .dropDuplicates(["date","sku","channel"])
)

gold_df = (sales.alias("s")
  .join(
      promo_prices.alias("p"),
      on=["date","sku","channel"], how="left"
  )
  .withColumn("unit_price_final", F.coalesce(F.col("p.effective_price"), F.col("s.price_unit")))
  .withColumn("line_revenue", F.col("s.units_sold") * F.col("unit_price_final"))
  .groupBy("s.date","s.region","s.brand","s.sku")
  .agg(
     F.sum("s.units_sold").alias("total_units_sold"),
     F.sum("line_revenue").alias("total_revenue"),
     # NULL when total stock is not positive (guards divide-by-zero, as in the other gold tables).
     F.when(F.sum("s.stock_available") > 0,
            F.sum("s.units_sold") / F.sum("s.stock_available")).alias("stock_utilization"),
     F.max("s.promotion_flag").alias("promotion_flag")
  )
)

# NOTE(review): partitioning by a daily `date` can produce many small partitions; confirm this suits Direct Lake.
(gold_df.write
   .mode("overwrite")
   .format("delta")
   .partitionBy("date")
   .saveAsTable(TBL["gold_bi"]))
print(f"Gold written: {TBL['gold_bi']}")
# ================== end notebook ==================
