In [0]:
from pyspark.sql.functions import *
from delta.tables import *

In [0]:
# Load the raw WMS inventory snapshot table that the cells below clean and validate.
BRONZE_SNAPSHOT_TABLE = "inventory_project.bronze.wms_inventory_snapshot_raw"
df = spark.read.table(BRONZE_SNAPSHOT_TABLE)

In [0]:
# --- Clean the raw snapshot -------------------------------------------------
# Strip surrounding whitespace from every column, keeping the original names.
# NOTE(review): trim() is a string function, so this presumably assumes all
# bronze columns arrive as strings -- confirm against the bronze schema.
df = df.select([trim(col(c)).alias(c) for c in df.columns])

# Convert the typed columns: snapshot_date is parsed with the yyyy-MM-dd
# pattern and quantity_on_hand is cast to an integer.
df = (
    df.withColumn("snapshot_date", to_date(col("snapshot_date"), "yyyy-MM-dd"))
      .withColumn("quantity_on_hand", col("quantity_on_hand").cast("int"))
)

# Map the free-text status onto a fixed vocabulary, ignoring case; anything
# that is not active / damaged / hold is labelled "Unknown".
status_lower = lower(col("status"))
df = df.withColumn(
    "status_clean",
    when(status_lower == "active", "Active")
    .when(status_lower == "damaged", "Damaged")
    .when(status_lower == "hold", "Hold")
    .otherwise("Unknown"),
)

# Columns that together identify one record; used to drop duplicate rows.
# A list (not a bare tuple) is used because a list of column names is the
# documented type of the `subset` argument of dropDuplicates.
key_col = ["snapshot_id", "product_id", "bin_id", "lot_id", "serial_id"]
df_dedup = df.dropDuplicates(key_col)

In [0]:
# --- Deduplicate ------------------------------------------------------------
# Columns that together identify one record. The same list drives both the
# deduplication and the MERGE condition below so the two can never disagree.
# A list (not a bare tuple) is used because a list of column names is the
# documented type of the `subset` argument of dropDuplicates.
key_col = ["snapshot_id", "product_id", "bin_id", "lot_id", "serial_id"]
df_dedup = df.dropDuplicates(key_col)

# --- Split into valid and quarantined rows ----------------------------------
# Single definition of "valid". The explicit isNotNull() on the quantity makes
# the expression evaluate to True/False for every row (never NULL), so the
# negation below is an exact complement: every deduplicated row lands in
# exactly one of df_valid / df_quarantine. Previously a row whose
# quantity_on_hand was NULL (missing, or nulled by the earlier int cast)
# matched neither filter and was silently dropped.
quantity = col("quantity_on_hand")
is_valid = (
    col("product_id").isNotNull()
    & col("snapshot_date").isNotNull()
    & quantity.isNotNull()
    & (quantity >= 0)
)

df_valid = df_dedup.filter(is_valid)

# The first failing rule, in the order listed, becomes the recorded reason.
df_quarantine = df_dedup.filter(~is_valid).withColumn(
    "dq_reason",
    when(col("product_id").isNull(), "Missing product_id")
    .when(col("snapshot_date").isNull(), "Missing snapshot_date")
    .when(quantity.isNull(), "Missing or invalid quantity_on_hand")
    .when(quantity < 0, "Negative stock")
    .otherwise("Unknown"),
)

# --- Upsert valid rows into the silver table --------------------------------
# Match on the full deduplication key. Matching on only snapshot_id and
# product_id lets several source rows (e.g. different bins of one product) hit
# the same target row, which MERGE cannot resolve. `<=>` is null-safe equality,
# so rows whose optional key parts are NULL match their existing counterpart
# instead of being inserted again on every run.
merge_condition = " AND ".join(f"t.{c} <=> s.{c}" for c in key_col)

dt = DeltaTable.forName(spark, "inventory_project.silver.wms_inventory_snapshot")
(
    dt.alias("t")
    .merge(df_valid.alias("s"), merge_condition)
    .whenMatchedUpdateAll()
    .whenNotMatchedInsertAll()
    .execute()
)

# --- Persist rejected rows and finish ---------------------------------------
# The quarantine output is fully replaced on each run (mode "overwrite").
df_quarantine.write.format("csv").mode("overwrite").save("/Volumes/inventory_project/silver/quarantine_layer/wms_inventory_snapshot")
dbutils.notebook.exit("SUCCESS")