# ML Enhancement: Stockout Prediction

Predicts stockout risk using LightGBM to identify store/product combinations at risk of running out of stock.

## Business Value
- Proactive replenishment before stockouts occur
- Reduce lost sales from out-of-stock items
- Optimize safety stock levels
- Prioritize expedited shipping decisions

## Model Details
- **Algorithm:** LightGBM Classifier with class imbalance handling
- **Target:** Binary stockout in next 3 days
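- **Features:** Current inventory, demand totals (7/14/30 days), daily demand velocity, average order size, days of inventory remaining, demand trend, inventory ratio, and calendar features (day of week, day of month, week of year, weekend flag)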
- **Performance Target:** Recall > 0.8, Precision > 0.5

## Data Flow
```
Silver (fact_store_inventory_txn, fact_receipts, fact_receipt_lines, dim_products) --> Feature Engineering --> LightGBM --> gold_stockout_risk
```

## Schedule
Run this notebook **daily** to generate updated stockout risk scores.

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, LongType, DoubleType, IntegerType, TimestampType, StringType
from datetime import datetime, timedelta, timezone
import os
import pandas as pd
import numpy as np
from lightgbm import LGBMClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

In [None]:
# =============================================================================
# PARAMETERS
# =============================================================================

def get_env(var_name, default=None):
    """Return the value of environment variable ``var_name``, or ``default`` when it is not set."""
    return os.getenv(var_name, default)

# Source (silver) and destination (gold) databases; each can be overridden
# through an environment variable of the same name.
SILVER_DB = get_env("SILVER_DB", default="ag")
GOLD_DB = get_env("GOLD_DB", default="au")

# Modelling constants.
FORECAST_HORIZON_DAYS = 3  # How far ahead a stockout is predicted
# NOTE(review): in the code visible here this value is only echoed in the
# configuration printout; the demand windows are written as literal 7/14/30 days.
LOOKBACK_DAYS = 30
STOCKOUT_THRESHOLD = 0  # A balance at or below this counts as a stockout

# LightGBM hyper-parameters; class_weight="balanced" compensates for class imbalance.
MODEL_PARAMS = dict(
    objective="binary",
    metric="binary_logloss",
    boosting_type="gbdt",
    num_leaves=31,
    learning_rate=0.05,
    n_estimators=200,
    class_weight="balanced",
    min_child_samples=20,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
    verbose=-1,
)

# Echo the effective configuration, one setting per line, followed by a blank line.
print(
    "Configuration:",
    f"  SILVER_DB: {SILVER_DB}",
    f"  GOLD_DB: {GOLD_DB}",
    f"  Forecast Horizon: {FORECAST_HORIZON_DAYS} days",
    f"  Lookback Period: {LOOKBACK_DAYS} days",
    sep="\n",
)
print()

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def read_silver(table_name):
    """Return the Spark DataFrame for ``table_name`` in the configured silver database."""
    qualified_name = "{}.{}".format(SILVER_DB, table_name)
    return spark.table(qualified_name)

def save_gold(df, table_name):
    """Overwrite ``{GOLD_DB}.{table_name}`` with ``df`` as a Delta table and log its row count.

    The row count is read back from the saved table instead of calling
    ``df.count()``: counting ``df`` after the write would execute its whole
    lineage a second time just to produce the log line.
    """
    full_name = f"{GOLD_DB}.{table_name}"
    df.write.format("delta").mode("overwrite").saveAsTable(full_name)
    row_count = spark.table(full_name).count()
    print(f"  {full_name}: {row_count} rows")

## Step 1: Calculate Current Inventory Position and Historical Demand

In [None]:
print("=" * 80)
print("EXTRACTING INVENTORY AND DEMAND DATA")
print("=" * 80)
print()

# --- Inventory: all transactions, then the most recent one per store/product ---
print("Loading current inventory position...")
inventory_df = read_silver("fact_store_inventory_txn").select(
    "store_id", "product_id", "event_ts", "balance", "delta"
)

# Number each store/product's transactions newest-first; row 1 is the latest balance.
window_latest = Window.partitionBy("store_id", "product_id").orderBy(F.col("event_ts").desc())
_ranked_inventory = inventory_df.withColumn("rn", F.row_number().over(window_latest))
current_inventory = _ranked_inventory.where(F.col("rn") == 1).select(
    "store_id",
    "product_id",
    F.col("event_ts").alias("inventory_as_of"),
    F.col("balance").alias("current_inventory"),
)

print(f"  Current inventory records: {current_inventory.count()}")
print()

# --- Sales: receipt lines joined to their receipt header for store and timestamp ---
print("Calculating demand velocity from sales...")
_receipt_lines = read_silver("fact_receipt_lines")
_receipt_headers = read_silver("fact_receipts").select("receipt_id_ext", "store_id", "event_ts")
sales_df = (
    _receipt_lines
    .join(_receipt_headers, on="receipt_id_ext", how="inner")
    .select("store_id", "product_id", "quantity", "event_ts")
    # Age of each sale in days, relative to the moment the query runs.
    .withColumn("days_ago", F.datediff(F.current_timestamp(), F.col("event_ts")))
)


def _quantity_within(days):
    """Return the sale quantity when the sale is at most ``days`` days old, else null."""
    return F.when(F.col("days_ago") <= days, F.col("quantity"))


# Per store/product: quantity sold in the trailing 7/14/30-day windows, the mean
# line quantity over the 30-day window, and the 30-day total spread over 30 days.
demand_velocity = (
    sales_df
    .groupBy("store_id", "product_id")
    .agg(
        F.sum(_quantity_within(7).otherwise(0)).alias("demand_7d"),
        F.sum(_quantity_within(14).otherwise(0)).alias("demand_14d"),
        F.sum(_quantity_within(30).otherwise(0)).alias("demand_30d"),
        F.avg(_quantity_within(30)).alias("avg_order_size_30d"),
    )
    .withColumn("demand_velocity_daily", F.col("demand_30d") / 30.0)
)

print(f"  Demand velocity records: {demand_velocity.count()}")
print()

## Step 2: Feature Engineering

In [None]:
print("=" * 80)
print("FEATURE ENGINEERING")
print("=" * 80)
print()

# Left join: every store/product with an inventory position is kept, including
# those that have no sales rows (their demand columns are null until filled below).
features_df = current_inventory.join(demand_velocity, on=["store_id", "product_id"], how="left")

# Product hierarchy attributes (Department / Category / Subcategory), keyed by product_id.
print("Loading product attributes...")
products_df = read_silver("dim_products").select(
    F.col("ID").alias("product_id"), "Department", "Category", "Subcategory"
)

features_df = features_df.join(products_df, on="product_id", how="left")

# Calendar features are derived from the timestamp of the latest inventory transaction.
print("Adding time-based features...")
_as_of = F.col("inventory_as_of")
_weekday = F.dayofweek(_as_of)
features_df = (
    features_df
    .withColumn("day_of_week", _weekday)
    .withColumn("day_of_month", F.dayofmonth(_as_of))
    .withColumn("week_of_year", F.weekofyear(_as_of))
    # Spark's dayofweek numbers Sunday as 1 and Saturday as 7.
    .withColumn("is_weekend", F.when(_weekday.isin([1, 7]), 1).otherwise(0))
)

print("Calculating derived features...")
_demand_columns = ["demand_7d", "demand_14d", "demand_30d", "demand_velocity_daily", "avg_order_size_30d"]
_inventory = F.col("current_inventory")
_daily_demand = F.col("demand_velocity_daily")
_demand_7d = F.col("demand_7d")
_demand_30d = F.col("demand_30d")
_has_30d_demand = _demand_30d > 0

features_df = (
    features_df
    # Store/products without sales get zero demand; this must precede the ratios below.
    .fillna(0, subset=_demand_columns)
    # Days of cover at the current daily demand; 999 stands in when there is no demand.
    .withColumn("days_of_inventory", F.when(_daily_demand > 0, _inventory / _daily_demand).otherwise(999))
    # Recent (7-day) daily demand relative to the 30-day daily demand; 1.0 when there is no 30-day demand.
    .withColumn("demand_trend", F.when(_has_30d_demand, (_demand_7d / 7.0) / (_demand_30d / 30.0)).otherwise(1.0))
    # Current inventory as a fraction of 30-day demand; 1.0 when there is no 30-day demand.
    .withColumn("inventory_ratio", F.when(_has_30d_demand, _inventory / _demand_30d).otherwise(1.0))
)

print(f"  Total feature records: {features_df.count()}")
print()

# Preview a handful of engineered rows.
print("Sample features:")
_preview_columns = [
    "store_id", "product_id", "current_inventory", "demand_velocity_daily",
    "days_of_inventory", "demand_trend", "day_of_week",
]
features_df.select(*_preview_columns).show(10)

## Step 3: Create Training Labels (Historical Stockouts)

In [None]:
print("="*80)
print("CREATING TRAINING LABELS")
print("="*80)
print()

# Find historical stockout events
print(f"Identifying stockouts (balance <= {STOCKOUT_THRESHOLD})...")

# Flag every inventory transaction whose resulting balance is at or below the threshold.
inventory_history = (
    read_silver("fact_store_inventory_txn")
    .select("store_id", "product_id", "event_ts", "balance")
    .withColumn("is_stockout", F.when(F.col("balance") <= STOCKOUT_THRESHOLD, 1).otherwise(0))
)

# Label: did ANY later transaction of the same store/product, within the forecast
# horizon, leave the balance in stockout?  A range frame over the event time in
# seconds covers every event in (event_ts, event_ts + horizon]; looking only at
# the single next event would miss a stockout that happens two or more events
# later but still inside the horizon.
# NOTE(review): assumes event_ts is a timestamp (or a string castable to one).
horizon_seconds = FORECAST_HORIZON_DAYS * 24 * 60 * 60
window_forward = (
    Window.partitionBy("store_id", "product_id")
    .orderBy(F.col("event_ts").cast("timestamp").cast("long"))
    .rangeBetween(1, horizon_seconds)
)

labeled_data = inventory_history.withColumn(
    "stockout_in_3_days",
    # max() over an empty frame is null (no later event inside the horizon) -> label 0.
    F.coalesce(F.max("is_stockout").over(window_forward), F.lit(0)),
)

# Keep only events old enough for their whole label window to have been observed.
# The row id keeps duplicate transactions apart when aggregating after the join.
label_cutoff = F.current_timestamp() - F.expr(f"INTERVAL {FORECAST_HORIZON_DAYS} DAYS")
training_events = (
    labeled_data
    .select("store_id", "product_id", "event_ts", "balance", "stockout_in_3_days")
    .filter(F.col("event_ts") <= label_cutoff)
    .withColumn("_event_row_id", F.monotonically_increasing_id())
)

# Point-in-time features.  Each training row must describe the world AS OF its own
# event_ts.  Joining the labels to features_df on store/product alone would give
# every historical row the same *current* snapshot: features computed from data
# that post-dates the label (target leakage) and that are identical across all of
# a store/product's rows.  Instead, the same features that features_df holds are
# rebuilt here from the balance at the event and the sales at or before it.
sales_history = sales_df.select(
    F.col("store_id").alias("sale_store_id"),
    F.col("product_id").alias("sale_product_id"),
    F.col("event_ts").alias("sale_ts"),
    F.col("quantity").alias("sale_quantity"),
)

days_before_event = F.datediff(F.col("event_ts"), F.col("sale_ts"))
prior_sales = (
    (F.col("store_id") == F.col("sale_store_id"))
    & (F.col("product_id") == F.col("sale_product_id"))
    & (F.col("sale_ts") <= F.col("event_ts"))
    & (days_before_event <= 30)  # widest demand window used by the features
)


def _prior_quantity(days):
    """Return the sale quantity when the sale is at most ``days`` days before the event, else null."""
    return F.when(days_before_event <= days, F.col("sale_quantity"))


# NOTE: only the model features, the label and the event keys are produced; the
# training cell reads nothing else from training_data.
training_data = (
    training_events
    # Left join so events with no prior sales are kept (their demand becomes 0).
    .join(sales_history, on=prior_sales, how="left")
    .groupBy("_event_row_id", "store_id", "product_id", "event_ts", "balance", "stockout_in_3_days")
    .agg(
        F.sum(_prior_quantity(7).otherwise(0)).alias("demand_7d"),
        F.sum(_prior_quantity(14).otherwise(0)).alias("demand_14d"),
        F.sum(_prior_quantity(30).otherwise(0)).alias("demand_30d"),
        F.avg(_prior_quantity(30)).alias("avg_order_size_30d"),
    )
    .drop("_event_row_id")
    # Inventory position at the event is the balance recorded by that transaction.
    .withColumn("current_inventory", F.col("balance"))
    .withColumn("demand_velocity_daily", F.col("demand_30d") / 30.0)
    .fillna(0, subset=["demand_7d", "demand_14d", "demand_30d", "demand_velocity_daily", "avg_order_size_30d"])
    .withColumn(
        "days_of_inventory",
        F.when(F.col("demand_velocity_daily") > 0, F.col("current_inventory") / F.col("demand_velocity_daily")).otherwise(999)
    )
    .withColumn(
        "demand_trend",
        F.when(
            F.col("demand_30d") > 0,
            (F.col("demand_7d") / 7.0) / (F.col("demand_30d") / 30.0)
        ).otherwise(1.0)
    )
    .withColumn(
        "inventory_ratio",
        F.when(F.col("demand_30d") > 0, F.col("current_inventory") / F.col("demand_30d")).otherwise(1.0)
    )
    # Calendar features come from the event's own timestamp.
    .withColumn("day_of_week", F.dayofweek(F.col("event_ts")))
    .withColumn("day_of_month", F.dayofmonth(F.col("event_ts")))
    .withColumn("week_of_year", F.weekofyear(F.col("event_ts")))
    .withColumn("is_weekend", F.when(F.dayofweek(F.col("event_ts")).isin([1, 7]), 1).otherwise(0))
)

print(f"  Training records: {training_data.count()}")
print()

# Check class distribution
print("Class distribution:")
training_data.groupBy("stockout_in_3_days").count().show()

## Step 4: Train LightGBM Model

In [None]:
print("=" * 80)
print("TRAINING LIGHTGBM MODEL")
print("=" * 80)
print()

# The scikit-learn style estimator is fitted on pandas data, so collect the
# Spark training set first.
print("Converting to Pandas...")
training_pd = training_data.toPandas()

# Model inputs, in the column order used for fitting and, later, for scoring.
feature_cols = [
    # inventory and demand levels
    "current_inventory",
    "demand_7d",
    "demand_14d",
    "demand_30d",
    "demand_velocity_daily",
    "avg_order_size_30d",
    # derived ratios
    "days_of_inventory",
    "demand_trend",
    "inventory_ratio",
    # calendar
    "day_of_week",
    "day_of_month",
    "week_of_year",
    "is_weekend",
]

# Zero-fill whatever nulls are still present in the model inputs.
training_pd[feature_cols] = training_pd[feature_cols].fillna(0)

# Feature matrix and binary target.
X, y = training_pd[feature_cols], training_pd["stockout_in_3_days"]

print(f"  Features shape: {X.shape}")
print(f"  Target distribution: {y.value_counts().to_dict()}")
print()

# 80/20 hold-out, stratified on the target so both parts keep the class ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)

print(f"  Training set: {len(X_train)} samples")
print(f"  Test set: {len(X_test)} samples")
print()

# Fit the classifier with the notebook-level hyper-parameters.
print("Training LightGBM classifier...")
model = LGBMClassifier(**MODEL_PARAMS)
model.fit(X_train, y_train)

print("  Training complete!")
print()

## Step 5: Model Evaluation

In [None]:
print("="*80)
print("MODEL EVALUATION")
print("="*80)
print()

# Hard class predictions and class-1 probabilities for the held-out set.
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)[:, 1]

# Pin the label set so both reports always cover "no stockout" (0) and "stockout" (1).
# Without it, a test set or prediction vector containing a single class yields a
# 1x1 confusion matrix (the 4-way unpack below fails) and a target_names mismatch.
class_labels = [0, 1]

# Classification report
print("Classification Report:")
print(classification_report(y_test, y_pred, labels=class_labels, target_names=["No Stockout", "Stockout"]))
print()

# Confusion matrix (rows: actual, columns: predicted, in class_labels order)
print("Confusion Matrix:")
cm = confusion_matrix(y_test, y_pred, labels=class_labels)
print(cm)
print()

# Calculate key metrics; each ratio falls back to 0 when its denominator is 0.
tn, fp, fn, tp = cm.ravel()
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print(f"Precision: {precision:.3f}")
print(f"Recall: {recall:.3f}")
print(f"F1-Score: {f1:.3f}")
print()

# ROC AUC: roc_auc_score raises ValueError when y_test holds a single class.
# Only that case is treated as "unable to compute"; any other error propagates.
try:
    roc_auc = roc_auc_score(y_test, y_pred_proba)
except ValueError:
    print("ROC AUC: Unable to compute (may need more positive samples)")
else:
    print(f"ROC AUC: {roc_auc:.3f}")
print()

# Check acceptance criteria (reported only; a failure does not stop the notebook)
print("="*80)
print("ACCEPTANCE CRITERIA CHECK")
print("="*80)
print(f"Recall > 0.8: {'✓ PASS' if recall > 0.8 else '✗ FAIL'} (actual: {recall:.3f})")
print(f"Precision > 0.5: {'✓ PASS' if precision > 0.5 else '✗ FAIL'} (actual: {precision:.3f})")
print()

# Feature importance, highest first
print("Top 10 Most Important Features:")
feature_importance = pd.DataFrame({
    'feature': feature_cols,
    'importance': model.feature_importances_
}).sort_values('importance', ascending=False)

print(feature_importance.head(10).to_string(index=False))
print()

## Step 6: Generate Predictions for Current Inventory

In [None]:
print("=" * 80)
print("GENERATING STOCKOUT RISK PREDICTIONS")
print("=" * 80)
print()

# Collect the current-snapshot features and zero-fill the model inputs, the same
# null handling that is applied to the training frame.
current_pd = features_df.toPandas()
current_pd[feature_cols] = current_pd[feature_cols].fillna(0)

# Score every store/product: second predict_proba column as the stockout
# probability, plus the model's own hard 0/1 prediction.
_current_features = current_pd[feature_cols]
current_pd = current_pd.assign(
    stockout_probability=model.predict_proba(_current_features)[:, 1],
    stockout_predicted=model.predict(_current_features),
)

# Add risk categories
def categorize_risk(prob):
    """Map a stockout probability to "HIGH" (>= 0.7), "MEDIUM" (>= 0.4) or "LOW" (anything else)."""
    # Thresholds are checked from highest to lowest; the first one reached wins.
    for floor, label in ((0.7, "HIGH"), (0.4, "MEDIUM")):
        if prob >= floor:
            return label
    return "LOW"

# Bucket each probability into HIGH / MEDIUM / LOW.
current_pd['risk_level'] = current_pd['stockout_probability'].apply(categorize_risk)

# Stamp the batch: one UTC timestamp shared by every row, plus the horizon the scores refer to.
_scored_at = pd.Timestamp.now(tz=timezone.utc)
current_pd['predicted_at'] = _scored_at
current_pd['forecast_horizon_days'] = FORECAST_HORIZON_DAYS

print(f"  Predictions generated: {current_pd.shape[0]}")
print()

# How many store/products fall into each risk bucket.
print("Risk Distribution:")
print(current_pd['risk_level'].value_counts())
print()

# The 20 store/products with the highest stockout probability.
print("Top 20 Stockout Risks:")
_report_columns = [
    'store_id', 'product_id', 'current_inventory', 'demand_velocity_daily',
    'days_of_inventory', 'stockout_probability', 'risk_level',
]
top_risks = current_pd[_report_columns].nlargest(20, 'stockout_probability')
print(top_risks.to_string(index=False))
print()

## Step 7: Save to Gold Layer

In [None]:
print("=" * 80)
print("SAVING TO GOLD LAYER")
print("=" * 80)
print()

# Published columns as (name, Spark type, nullable), in output order.  The column
# list and the Spark schema are both derived from this single table.
_output_fields = [
    ("store_id", LongType(), False),
    ("product_id", LongType(), False),
    ("current_inventory", LongType(), True),
    ("demand_velocity_daily", DoubleType(), True),
    ("days_of_inventory", DoubleType(), True),
    ("demand_trend", DoubleType(), True),
    ("stockout_probability", DoubleType(), False),
    ("stockout_predicted", IntegerType(), False),
    ("risk_level", StringType(), False),
    ("predicted_at", TimestampType(), False),
    ("forecast_horizon_days", IntegerType(), False),
    ("Department", StringType(), True),
    ("Category", StringType(), True),
    ("Subcategory", StringType(), True),
]

# Select output columns
output_cols = [field_name for field_name, _, _ in _output_fields]
output_pd = current_pd[output_cols]

# Convert back to a Spark DataFrame with an explicit schema.
output_schema = StructType(
    [StructField(field_name, field_type, nullable) for field_name, field_type, nullable in _output_fields]
)
output_spark_df = spark.createDataFrame(output_pd, schema=output_schema)

# Save to gold layer
print("Saving to gold_stockout_risk...")
save_gold(output_spark_df, "gold_stockout_risk")
print()

print("=" * 80)
print("STOCKOUT PREDICTION COMPLETE")
print("=" * 80)
print()
print(f"Output table: {GOLD_DB}.gold_stockout_risk")
print("Next Steps:")
for _next_step in (
    "  1. Create Power BI dashboard for replenishment alerts",
    "  2. Set up notifications for HIGH risk items",
    "  3. Schedule this notebook to run daily",
    "  4. Monitor prediction accuracy and retrain as needed",
):
    print(_next_step)