# In [0]:
# Import necessary libraries
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, TimestampType, StringType, LongType
from pyspark.sql.utils import AnalysisException
from datetime import datetime

# COMMAND ----------

# DBTITLE 1,Configuration and Setup
# Silver layer location this job reads from: (catalog, schema).
source_catalog, source_schema = "silver", "e-commerce-sales"

# Gold layer location this job writes to: (catalog, schema), followed by the
# names of the metrics table and the audit log table used further down.
target_catalog, target_schema = "gold", "e-commerce-sales"
target_table = "daily_sales_metrics"
audit_log_table = "audit_logs"

# Echo the configured source and destination in the cell output.
print(f"Source: '{source_catalog}.{source_schema}'")
print(f"Destination: '{target_catalog}.{target_schema}.{target_table}'")


# COMMAND ----------

# DBTITLE 1,Audit Logging Function
# Fully qualified name of the audit log table. The schema and table parts are
# wrapped in backticks because the configured schema name contains hyphens.
audit_log_table_full_name = f"{target_catalog}.`{target_schema}`.`{audit_log_table}`"

def log_audit(process_name, status, row_count=None, message=None):
    """
    Append one audit record to the Gold layer audit log table.

    The record holds the current time (``datetime.now()``), the process name,
    its status, and the optional row count and message. It is written as a
    single-row Delta append to ``audit_log_table_full_name``.

    Logging is best-effort: any exception raised while building or writing
    the record is caught and reported with ``print``; nothing is re-raised.

    Args:
        process_name: Name of the pipeline stage being reported.
        status: Status label for the stage (e.g. "Started", "Success", "Failed").
        row_count: Optional number of rows associated with the event.
        message: Optional free-text detail, such as an error description.
    """
    try:
        # Column layout of the audit table as (name, type, nullable).
        column_specs = [
            ("timestamp", TimestampType(), False),
            ("process_name", StringType(), False),
            ("status", StringType(), False),
            ("row_count", LongType(), True),
            ("message", StringType(), True),
        ]
        audit_schema = StructType(
            [StructField(name, dtype, nullable) for name, dtype, nullable in column_specs]
        )

        # Exactly one row per call, in the same order as column_specs.
        audit_row = (datetime.now(), process_name, status, row_count, message)
        audit_df = spark.createDataFrame([audit_row], schema=audit_schema)

        audit_writer = audit_df.write.format("delta").mode("append")
        audit_writer.saveAsTable(audit_log_table_full_name)
        print(f"Logged '{status}' for process '{process_name}'.")

    except Exception as err:
        # Deliberately swallowed so an audit failure never stops the pipeline.
        print(f"FATAL: Could not write to audit log table. Error: {err}")

# Record that the Gold job as a whole has begun.
log_audit(
    process_name="Gold Layer Job",
    status="Started",
    message="Gold layer aggregation job initiated.",
)

# COMMAND ----------

# DBTITLE 1,Load Silver Tables
process_name = "Load Silver Tables"
try:
    log_audit(process_name, "Started")

    # All four inputs live in the same Silver catalog/schema; the schema name
    # is backtick-quoted because it contains hyphens.
    silver_namespace = f"{source_catalog}.`{source_schema}`"
    orders_df = spark.table(f"{silver_namespace}.orders")
    order_items_df = spark.table(f"{silver_namespace}.order_items")
    users_df = spark.table(f"{silver_namespace}.users")
    products_df = spark.table(f"{silver_namespace}.products")

    log_audit(process_name, "Success")
except Exception as load_error:
    # Audit the failure, then stop the notebook with the error details.
    # NOTE(review): dbutils.notebook.exit ends the run with a return value
    # rather than an error -- confirm the scheduler should see it this way.
    log_audit(process_name, "Failed", message=str(load_error))
    dbutils.notebook.exit(f"Failed to load source Silver tables. Details: {load_error}")


# COMMAND ----------

# DBTITLE 1,Join and Enrich Data
process_name = "Enrichment and Joins"
try:
    log_audit(process_name, "Started")

    # Step 1: pair each order with its line items (inner join on order_id)
    # and keep only the columns needed downstream.
    orders_with_items_df = orders_df.join(
        order_items_df,
        on=orders_df.order_id == order_items_df.order_id,
        how="inner",
    )
    sales_df = orders_with_items_df.select(
        orders_df.order_id,
        orders_df.created_at.alias("order_date"),
        orders_df.user_id,
        order_items_df.product_id,
        order_items_df.sale_price,
    )

    # Step 2: attach product cost and category. A left join keeps sales rows
    # even when no product matches.
    sales_and_products_df = sales_df.join(
        products_df,
        on=sales_df.product_id == products_df.id,
        how="left",
    )
    sales_with_products_df = sales_and_products_df.select(
        sales_df["*"],
        products_df.cost,
        products_df.category.alias("product_category"),
    )

    # Step 3: attach the customer's state (left join on the user id), derive
    # the calendar date from order_date, and add profit = sale_price - cost.
    final_df = (
        sales_with_products_df.join(
            users_df,
            on=sales_with_products_df.user_id == users_df.id,
            how="left",
        )
        .select(
            F.to_date("order_date").alias("date"),
            sales_with_products_df["*"],
            users_df.state.alias("customer_state"),
        )
        .withColumn("profit", F.col("sale_price") - F.col("cost"))
    )

    log_audit(process_name, "Success")
except Exception as enrich_error:
    # Audit the failure, then stop the notebook with the error details.
    log_audit(process_name, "Failed", message=str(enrich_error))
    dbutils.notebook.exit(f"Failed during data enrichment. Details: {enrich_error}")


# COMMAND ----------

# DBTITLE 1,Aggregate Metrics for Gold Table
process_name = "Aggregation"
try:
    log_audit(process_name, "Started")

    # One output row per (date, customer state, product category).
    grouping_keys = ["date", "customer_state", "product_category"]

    # Metrics computed for each group.
    metric_exprs = [
        F.sum("sale_price").alias("total_revenue"),
        F.sum("cost").alias("total_cost"),
        F.sum("profit").alias("total_profit"),
        F.count("order_id").alias("items_sold_count"),
        F.countDistinct("order_id").alias("distinct_order_count"),
        F.countDistinct("product_id").alias("distinct_products_sold"),
    ]

    # Cached because the result is consumed twice: by the count below and by
    # the write to the Gold table.
    daily_sales_metrics = final_df.groupBy(*grouping_keys).agg(*metric_exprs).cache()

    # Row count of the aggregated result, recorded in the audit log.
    final_row_count = daily_sales_metrics.count()
    log_audit(process_name, "Success", row_count=final_row_count)
except Exception as agg_error:
    # Audit the failure, then stop the notebook with the error details.
    log_audit(process_name, "Failed", message=str(agg_error))
    dbutils.notebook.exit(f"Failed during aggregation. Details: {agg_error}")

# COMMAND ----------

# DBTITLE 1,Write to Gold Layer
process_name = "Write to Gold Table"
# Fully qualified destination table; schema and table are backtick-quoted
# because the configured schema name contains hyphens.
gold_table_full_name = f"{target_catalog}.`{target_schema}`.`{target_table}`"

try:
    log_audit(process_name, "Started", row_count=final_row_count, message=f"Attempting to write to {gold_table_full_name}")

    # Replace the whole table (data and schema) with the freshly aggregated
    # metrics.
    daily_sales_metrics.write \
        .format("delta") \
        .mode("overwrite") \
        .option("overwriteSchema", "true") \
        .saveAsTable(gold_table_full_name)

    log_audit(process_name, "Success", row_count=final_row_count, message=f"Successfully wrote {final_row_count} rows.")

except Exception as e:
    # Audit the failure, then stop the notebook with the error details.
    log_audit(process_name, "Failed", message=str(e))
    dbutils.notebook.exit(f"Failed to write to Gold table. Details: {e}")

finally:
    # The aggregated DataFrame was cached for the count and this write; it is
    # not used again afterwards (the previews read the saved table), so
    # release the cached data whether or not the write succeeded.
    daily_sales_metrics.unpersist()

print("-" * 50)

# COMMAND ----------

# DBTITLE 1,Finalize Job
# Record that the Gold job as a whole has finished.
log_audit(
    process_name="Gold Layer Job",
    status="Finished",
    message="Gold layer aggregation job completed successfully.",
)


# COMMAND ----------


# COMMAND ----------

# Show the first 20 rows of the Gold table, ordered by newest date and then
# by highest revenue.
print(f"Preview of the final Gold table: {gold_table_full_name}")
gold_preview_df = (
    spark.table(gold_table_full_name)
    .orderBy(F.desc("date"), F.desc("total_revenue"))
    .limit(20)
)
display(gold_preview_df)

# COMMAND ----------

# Show the audit log with the most recent entries first.
print(f"Preview of the Gold audit log: {audit_log_table_full_name}")
audit_preview_df = spark.table(audit_log_table_full_name).orderBy(F.desc("timestamp"))
display(audit_preview_df)
