# ML Demand Forecasting with Prophet

Generates 14-day demand forecasts by store and product using Facebook Prophet.

## Data Flow
```
fact_receipt_lines + fact_receipts (Silver) --> gold_demand_forecast (Gold)
```

## Model Details
- **Algorithm:** Facebook Prophet (additive time series)
- **Granularity:** Store × Product × Day
- **Forecast horizon:** 14 days
- **Features:** Historical sales with daily/weekly seasonality
- **Target metric:** MAPE < 25% for top 80% products by volume

## Usage
Schedule this notebook to run **daily at 6 AM** via a Fabric pipeline to ensure fresh forecasts.

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException
from datetime import datetime, timezone, timedelta
import pandas as pd
import os

# Prophet import
try:
    from prophet import Prophet
except ImportError:
    print("Installing prophet...")
    !pip install prophet
    from prophet import Prophet

In [None]:
# =============================================================================
# PARAMETERS
# =============================================================================

def get_env(var_name, default=None):
    """Return the environment variable ``var_name``, or ``default`` if it is unset."""
    try:
        return os.environ[var_name]
    except KeyError:
        return default

# Source (Silver) and target (Gold) database names. Each one is read from the
# environment variable of the same name and falls back to the default below.
SILVER_DB = get_env("SILVER_DB", "ag")
GOLD_DB = get_env("GOLD_DB", "au")

# Forecasting parameters
FORECAST_HORIZON_DAYS = 14  # Number of future days to forecast
MIN_HISTORY_DAYS = 30       # Minimum historical data required
MIN_DAILY_SALES = 0.5       # Minimum average daily sales to forecast
TARGET_MAPE = 25.0          # Target MAPE threshold (percent)

# Echo the effective configuration so it is visible in the notebook run log.
print(
    f"Configuration: SILVER_DB={SILVER_DB}, GOLD_DB={GOLD_DB}",
    f"Forecast horizon: {FORECAST_HORIZON_DAYS} days",
    f"Minimum history: {MIN_HISTORY_DAYS} days",
    f"Target MAPE: {TARGET_MAPE}%",
    sep="\n",
)

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def ensure_database(name):
    """Create the database ``name`` through Spark SQL unless it already exists."""
    create_statement = f"CREATE DATABASE IF NOT EXISTS {name}"
    spark.sql(create_statement)

def read_silver(table_name):
    """Return the result of ``spark.table`` for ``table_name`` in the Silver database."""
    qualified_name = f"{SILVER_DB}.{table_name}"
    return spark.table(qualified_name)

def silver_exists(table_name):
    """Report whether ``table_name`` can be resolved in the Silver database.

    Returns:
        True when ``spark.table`` resolves the table, False when it raises
        ``AnalysisException``.
    """
    qualified_name = f"{SILVER_DB}.{table_name}"
    try:
        spark.table(qualified_name)
    except AnalysisException:
        return False
    else:
        return True

# Make sure the Gold database exists before the forecast table is written to it.
ensure_database(GOLD_DB)

In [None]:
# Section banner for the notebook output.
print("=" * 60, "DATA PREPARATION", "=" * 60, sep="\n")

In [None]:
# Build the daily store x product sales history used to train the forecasts.
#
# Both Silver inputs are verified up front: the aggregation below reads
# fact_receipt_lines AND fact_receipts, so a missing table of either name
# should fail with the same explicit error.
for required_table in ("fact_receipt_lines", "fact_receipts"):
    if not silver_exists(required_table):
        raise RuntimeError(f"{required_table} table not found in Silver layer")

print("Reading fact_receipt_lines and joining with receipts...")

# Receipt lines carry the product/quantity; receipts supply the store_id.
df_lines = read_silver("fact_receipt_lines")
df_receipts = read_silver("fact_receipts")

# Aggregate to daily sales by store and product.
# The result is cached because it feeds several Spark actions below (three
# counts here plus the later collect); without caching each action would
# re-run the join and aggregation.
df_daily_sales = (
    df_lines
    .join(df_receipts.select("receipt_id_ext", "store_id"), on="receipt_id_ext")
    .withColumn("sale_date", F.to_date("event_ts"))
    .groupBy("store_id", "product_id", "sale_date")
    .agg(
        F.sum("quantity").alias("units_sold"),
        F.sum("ext_price").alias("revenue")
    )
    .cache()
)

print(f"Daily sales aggregated: {df_daily_sales.count()} rows")

# Filter to store-product combinations with sufficient history.
# history_days is the calendar span between the first and last sale date,
# not the number of days that actually had sales (that is days_with_sales).
# NOTE(review): avg_daily_sales averages only over days that have a row in
# df_daily_sales (days without any sale are absent), so it is an average per
# selling day rather than per calendar day - confirm this matches the intent
# of MIN_DAILY_SALES.
df_history_check = (
    df_daily_sales
    .groupBy("store_id", "product_id")
    .agg(
        F.min("sale_date").alias("first_sale"),
        F.max("sale_date").alias("last_sale"),
        F.count("sale_date").alias("days_with_sales"),
        F.avg("units_sold").alias("avg_daily_sales")
    )
    .withColumn(
        "history_days",
        F.datediff(F.col("last_sale"), F.col("first_sale"))
    )
    .filter(
        (F.col("history_days") >= MIN_HISTORY_DAYS) &
        (F.col("avg_daily_sales") >= MIN_DAILY_SALES)
    )
)

print(f"Store-product combinations with sufficient history: {df_history_check.count()}")

# Join back to keep only the daily rows of qualifying combinations.
df_training_data = (
    df_daily_sales
    .join(
        df_history_check.select("store_id", "product_id"),
        on=["store_id", "product_id"],
        how="inner"
    )
)

print(f"Training data prepared: {df_training_data.count()} rows")

In [None]:
# Section banner for the notebook output (preceded by a blank line).
print("\n" + "=" * 60, "MODEL TRAINING & FORECASTING", "=" * 60, sep="\n")

In [None]:
def train_and_forecast(store_id, product_id, sales_data):
    """
    Train a Prophet model and forecast demand for one store-product combination.

    Args:
        store_id: Store identifier, copied unchanged into every result record.
        product_id: Product identifier, copied unchanged into every result record.
        sales_data: List of tuples (sale_date, units_sold). Dates are
            normalised with ``pd.to_datetime`` before fitting.

    Returns:
        List of tuples, one per forecast day after the last observed date:
        (store_id, product_id, forecast_date, predicted_units, lower_bound,
        upper_bound, mape, generated_at). The three unit values are clipped
        at zero. ``mape`` is the in-sample MAPE in percent over observations
        with y > 0, or None when no such observation exists. ``generated_at``
        is a single timezone-aware UTC ``datetime`` shared by all records of
        the call. An empty list is returned when fewer than two observations
        are supplied or when fitting/prediction raises (the error is printed).
    """
    import logging

    # Prepare data in Prophet format
    df_prophet = pd.DataFrame(sales_data, columns=["ds", "y"])

    # Handle edge cases
    if len(df_prophet) < 2:
        return []

    try:
        # Normalise dates to datetime64 so that "ds" has the same dtype as the
        # "ds" column of Prophet's forecast frame. Left as plain date objects,
        # the merge and the ">" comparison below mix object and datetime64
        # values, which current pandas rejects - and the broad except would
        # then silently turn every combination into an empty result.
        df_prophet["ds"] = pd.to_datetime(df_prophet["ds"])

        # Suppress Prophet logs
        logging.getLogger('prophet').setLevel(logging.ERROR)

        # NOTE(review): the input is one row per day, so daily_seasonality
        # (a within-day pattern) presumably has nothing to fit - confirm
        # whether it should stay enabled.
        # NOTE(review): days without a row are treated as missing by the
        # model, not as zero sales - confirm against how callers build
        # sales_data.
        model = Prophet(
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=False,  # Not enough history
            changepoint_prior_scale=0.05,  # Less sensitive to changes
            seasonality_prior_scale=10.0,
            interval_width=0.95  # 95% confidence intervals
        )
        model.fit(df_prophet)

        # Predict over the history plus the forecast horizon
        future = model.make_future_dataframe(periods=FORECAST_HORIZON_DAYS)
        forecast = model.predict(future)

        # In-sample MAPE: compare fitted values with the observed history.
        # The inner merge keeps only dates present in the history.
        merged = df_prophet.merge(forecast[['ds', 'yhat']], on='ds')
        merged = merged[merged['y'] > 0]  # Avoid division by zero
        mape = None
        if not merged.empty:
            abs_pct_error = (merged['y'] - merged['yhat']).abs() / merged['y']
            mape = float(abs_pct_error.mean() * 100)

        # Extract future forecasts only
        last_observed = df_prophet['ds'].max()
        future_forecast = forecast[forecast['ds'] > last_observed]

        # One plain-datetime UTC timestamp for the whole batch of records.
        generated_at = datetime.now(timezone.utc)

        # Build result records; negative predictions are clipped to zero.
        return [
            (
                store_id,
                product_id,
                ds.date(),
                max(0.0, float(yhat)),        # Predicted units (non-negative)
                max(0.0, float(yhat_lower)),  # Lower bound
                max(0.0, float(yhat_upper)),  # Upper bound
                mape,
                generated_at,
            )
            for ds, yhat, yhat_lower, yhat_upper in zip(
                future_forecast['ds'],
                future_forecast['yhat'],
                future_forecast['yhat_lower'],
                future_forecast['yhat_upper'],
            )
        ]

    except Exception as e:
        # Deliberate best-effort: one failing combination must not abort the
        # whole run, so log the error and return an empty list.
        print(f"Error forecasting store_id={store_id}, product_id={product_id}: {e}")
        return []

# Notebook progress message printed once the cell defining train_and_forecast has run.
print("Prophet forecasting function defined")

In [None]:
# Gather the full sales history of every store-product combination into one
# row each, pull the rows to the driver, and fit one model per combination.
# Note: For large datasets, consider sampling or processing in batches
print("Collecting training data by store and product...")

sales_point = F.struct("sale_date", "units_sold")
df_grouped = df_training_data.groupBy("store_id", "product_id").agg(
    F.collect_list(sales_point).alias("sales_history")
)

store_product_data = df_grouped.collect()
n_combinations = len(store_product_data)
print(f"Processing {n_combinations} store-product combinations...")

# Train models and generate forecasts
all_forecasts = []
success_count = 0
error_count = 0

for i, row in enumerate(store_product_data):
    # Progress line every 50 combinations
    if i % 50 == 0:
        print(f"  Progress: {i}/{n_combinations} combinations processed")

    store_id, product_id = row['store_id'], row['product_id']
    sales_history = [
        (point['sale_date'], float(point['units_sold']))
        for point in row['sales_history']
    ]

    forecasts = train_and_forecast(store_id, product_id, sales_history)

    # An empty result counts as a failure for this combination.
    if not forecasts:
        error_count += 1
        continue

    all_forecasts.extend(forecasts)
    success_count += 1

print(
    "\nForecasting complete:",
    f"  Successful: {success_count}",
    f"  Failed: {error_count}",
    f"  Total forecast records: {len(all_forecasts)}",
    sep="\n",
)

In [None]:
# Section banner for the notebook output (preceded by a blank line).
print("\n" + "=" * 60, "SAVE FORECASTS TO GOLD LAYER", "=" * 60, sep="\n")

In [None]:
# Turn the collected forecast tuples into a Spark DataFrame and write it to
# the Gold layer as a Delta table, replacing the previous contents.
if not all_forecasts:
    print("No forecasts generated. Check data quality and parameters.")
else:
    # Column name -> Spark type, in the positional order of the forecast tuples.
    column_types = {
        "store_id": "long",
        "product_id": "long",
        "forecast_date": "date",
        "predicted_units": "double",
        "lower_bound": "double",
        "upper_bound": "double",
        "mape": "double",
        "generated_at": "timestamp",
    }
    schema = list(column_types)

    df_forecasts = spark.createDataFrame(all_forecasts, schema=schema)

    # Cast types explicitly (column types are inferred on creation)
    for column_name, spark_type in column_types.items():
        df_forecasts = df_forecasts.withColumn(
            column_name, F.col(column_name).cast(spark_type)
        )

    # Save to Gold layer
    table_name = f"{GOLD_DB}.gold_demand_forecast"
    writer = df_forecasts.write.format("delta").mode("overwrite")
    writer.saveAsTable(table_name)

    print(f"Saved {df_forecasts.count()} forecast records to {table_name}")

    # Display sample forecasts
    print("\nSample forecasts:")
    df_sample = df_forecasts.orderBy("store_id", "product_id", "forecast_date")
    df_sample.show(10, truncate=False)

In [None]:
# Section banner for the notebook output (preceded by a blank line).
print("\n" + "=" * 60, "FORECAST ACCURACY METRICS", "=" * 60, sep="\n")

In [None]:
# Summarise forecast accuracy from the table that was just written.
if all_forecasts:
    df_metrics = spark.table(f"{GOLD_DB}.gold_demand_forecast")

    # Overall MAPE distribution
    print("MAPE Statistics:")
    df_metrics.select("mape").filter(F.col("mape").isNotNull()).describe().show()

    # Average MAPE per store-product combination. avg() ignores nulls, so a
    # combination without any MAPE yields a null average and is dropped here.
    df_avg_mape = (
        df_metrics
        .groupBy("store_id", "product_id")
        .agg(F.avg("mape").alias("avg_mape"))
        .filter(F.col("avg_mape").isNotNull())
    )

    # Count products meeting target MAPE
    total_combinations = df_metrics.select("store_id", "product_id").distinct().count()
    meeting_target = df_avg_mape.filter(F.col("avg_mape") < TARGET_MAPE).count()

    if total_combinations > 0:
        pct_meeting_target = meeting_target / total_combinations * 100
    else:
        pct_meeting_target = 0

    print(f"\nAccuracy Summary:")
    print(f"  Total store-product combinations: {total_combinations}")
    print(f"  Combinations meeting MAPE < {TARGET_MAPE}%: {meeting_target} ({pct_meeting_target:.1f}%)")

    if pct_meeting_target < 80:
        print(f"\n  WARNING: Only {pct_meeting_target:.1f}% meet target. Consider:")
        print("  - Increasing MIN_HISTORY_DAYS")
        print("  - Adjusting Prophet hyperparameters")
        print("  - Adding external regressors (promotions, holidays)")
    else:
        print(f"\n  SUCCESS: {pct_meeting_target:.1f}% of products meet the accuracy target!")

    # Top 10 best and worst performers by MAPE
    print("\nTop 10 Most Accurate Forecasts (by avg MAPE):")
    df_avg_mape.orderBy("avg_mape").show(10, truncate=False)

    print("\nTop 10 Least Accurate Forecasts (by avg MAPE):")
    df_avg_mape.orderBy(F.desc("avg_mape")).show(10, truncate=False)

In [None]:
# Closing banner plus a short summary of the Gold database contents.
print("\n" + "=" * 60, "DEMAND FORECASTING COMPLETE", "=" * 60, sep="\n")

gold_tables = spark.sql(f"SHOW TABLES IN {GOLD_DB}").collect()
print(
    f"\nGold ({GOLD_DB}): {len(gold_tables)} tables",
    "\nForecast table: gold_demand_forecast",
    f"Forecast horizon: {FORECAST_HORIZON_DAYS} days",
    "\nSchedule this notebook to run daily at 6 AM for fresh forecasts.",
    sep="\n",
)