In [0]:
# 05_business_decision_layer

**Purpose:**  
Transform ML predictions into actionable business decisions by generating  
delay risk scores, risk categories, and recommended actions.

**Layer:** Gold  
**Output Table:** logistics_gold.delivery_risk_decisions

In [0]:
%python
from pyspark.sql.functions import col, when
from pyspark.ml.functions import vector_to_array
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassificationModel

In [0]:
# Input table for scoring, looked up by name through the active Spark session.
SOURCE_TABLE = "logistics_gold.ml_delay_dataset"
ml_df = spark.table(SOURCE_TABLE)

In [0]:
# Columns of ml_df that are packed, in this order, into the "features" vector.
# NOTE(review): the order presumably has to match the one used when the model
# was trained — confirm against the training notebook.
feature_cols = [
    "distance_km",
    "shipment_weight_kg",
    "avg_traffic_index",
    "weather_severity_score",
    "historical_delay_rate",
    "avg_delay_days",
    "is_weekend",
    "is_peak_season",
]

# handleInvalid="skip" makes the assembler drop any row that has a null/NaN in
# one of the feature columns, so such shipments never reach the scored output.
assembler = (
    VectorAssembler()
    .setInputCols(feature_cols)
    .setOutputCol("features")
    .setHandleInvalid("skip")
)

# Keep only the identifier, the assembled vector and the is_late column.
scoring_columns = ["shipment_id", "features", "is_late"]
feature_df = assembler.transform(ml_df).select(*scoring_columns)

In [0]:
# Location of the persisted random-forest model that this notebook loads.
MODEL_PATH = "/Volumes/workspace/default/ml_models/random_forest_delay_model"
rf_model = RandomForestClassificationModel.load(MODEL_PATH)


In [0]:
# Run the loaded model over the assembled rows. The resulting DataFrame carries
# the "probability" column that the next cell turns into a risk score.
predictions_df = rf_model.transform(dataset=feature_df)

In [0]:
# Convert the ML probability vector to a plain array and take element 1 as the
# risk score.
# NOTE(review): element 1 is presumably P(is_late == 1) — confirm the label
# encoding used at training time.
class_one_probability = vector_to_array(col("probability")).getItem(1)
scored_df = predictions_df.withColumn("delay_risk_score", class_one_probability)

In [0]:
# Exclusive upper bounds of the "Low" and "Medium" score bands; any score that
# matches neither band is labelled "High".
LOW_RISK_UPPER_BOUND = 0.30
MEDIUM_RISK_UPPER_BOUND = 0.70

risk_score = col("delay_risk_score")
risk_category = (
    when(risk_score < LOW_RISK_UPPER_BOUND, "Low")
    .when(risk_score < MEDIUM_RISK_UPPER_BOUND, "Medium")
    .otherwise("High")
)

decision_df = scored_df.withColumn("risk_category", risk_category)

In [0]:
# Map each risk category to the action text written to the output table.
# Any category other than "High" or "Medium" gets "No action required".
risk_level = col("risk_category")
recommended_action = (
    when(risk_level == "High", "Reroute shipment / Alert operations")
    .when(risk_level == "Medium", "Monitor shipment closely")
    .otherwise("No action required")
)

decision_df = decision_df.withColumn("recommended_action", recommended_action)

In [0]:
# Columns, in order, that make up the persisted decision table.
output_columns = [
    "shipment_id",
    "delay_risk_score",
    "risk_category",
    "recommended_action",
    "is_late",
]
final_business_df = decision_df.select(*output_columns)

In [0]:
# Replace the Delta table on every run; overwriteSchema lets the table schema
# be replaced along with the data.
(
    final_business_df.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable("logistics_gold.delivery_risk_decisions")
)

In [0]:
%sql
-- Number of scored shipments in each risk category.
SELECT
  d.risk_category,
  COUNT(*)
FROM logistics_gold.delivery_risk_decisions AS d
GROUP BY
  d.risk_category;

risk_category,COUNT(*)
Medium,9517
Low,483


In [0]:
%sql
-- Mean of is_late within each risk category, to compare the assigned category
-- with the observed outcome.
SELECT
  d.risk_category,
  AVG(d.is_late) AS actual_delay_rate
FROM logistics_gold.delivery_risk_decisions AS d
GROUP BY
  d.risk_category;

risk_category,actual_delay_rate
Medium,0.3695492276978039
Low,0.0538302277432712
