# SmartSus Chef: The Universal Predictive Engine
**Version:** 3.0 (Production Ready) | **Context:** Singapore & China | **Architecture:** Champion-Challenger

## How to read this Notebook
This notebook runs a robust, parallelized ML pipeline to predict food demand. It acts as a high-level **orchestrator**, delegating the complex implementation details to the `training_logic.py` module.

### The End-to-End Workflow:
1.  **Context Detection:** Automatically determines the restaurant's location (e.g., SG or CN) to load the correct holiday and weather data.
2.  **Data Ingestion:** Fetches the raw sales history from a database or a fallback CSV file.
3.  **Data Cleaning & Sanitation:** Aggregates data to have one record per dish-day and fills any gaps from missing sales days.
4.  **Backtesting & Champion Selection:** For each dish, all three models (Prophet, CatBoost, XGBoost) are tuned using Optuna and evaluated via time-series cross-validation to find the "champion" with the lowest Mean Absolute Error (MAE).
5.  **Parallel Execution:** Each dish's evaluation runs in its own worker process, so multiple dishes are processed concurrently across CPU cores, which shortens total training time.
6.  **Production Training:** After a champion model is chosen, all three models are retrained on 100% of the data using the best-found parameters and saved to disk.
7.  **Forecasting:** A 14-day rolling forecast is generated using the saved production models. Each prediction is broken down into trend, seasonality, holiday, and weather contributions (SHAP-based for the tree models, forecast components for Prophet).

In [None]:
# --- IMPORTS & SETUP ---
import numpy as np
import pandas as pd
import pickle
import holidays
import shap
from concurrent.futures import ProcessPoolExecutor, as_completed
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Import core pipeline logic from our module
from training_logic import (
    PipelineConfig,
    CFG,
    process_dish,
    fetch_training_data,
    add_local_context,
    get_country_code,
    estimate_temperature,
    safe_filename,
)

print("Libraries loaded successfully.")

## Pipeline Overview: From Raw Data to Trained Models

The core of this project lies in the `training_logic.py` module, which is responsible for the heavy lifting. The notebook simply orchestrates the execution of this logic. The process for each dish is as follows:

### 1. Data Preparation & Feature Engineering
First, the pipeline prepares a clean, feature-rich dataset.
- **Context & Ingestion:** It begins by detecting the restaurant's location (`add_local_context`) and loading the complete sales history (`fetch_training_data`).
- **Cleaning & Enrichment:** The raw data is then cleaned to fill in missing dates (`sanitize_sparse_data`) and enriched with predictive features such as lags, rolling means and standard deviations, and trend indicators (`add_lag_features`). These features are the main inputs to the tree-based models.
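
As a rough illustration of what the cleaning and enrichment steps do, the sketch below fills missing dates and derives a few lag features with pandas. It is not the code in `training_logic.py`: the helper names `fill_missing_days` and `add_basic_lags` are hypothetical, and treating a missing day as zero sales is an assumption made for this example.

```python
import pandas as pd

def fill_missing_days(df: pd.DataFrame) -> pd.DataFrame:
    """Reindex one dish's sales to a continuous daily range (assumption: gaps mean zero sales)."""
    daily = df.groupby("ds")["sales"].sum()  # one record per day
    full_range = pd.date_range(daily.index.min(), daily.index.max(), freq="D")
    daily = daily.reindex(full_range, fill_value=0)
    return daily.rename_axis("ds").reset_index()

def add_basic_lags(df: pd.DataFrame) -> pd.DataFrame:
    """Add lag and rolling features that only look at earlier days."""
    out = df.copy()
    out["lag_1"] = out["sales"].shift(1)
    out["lag_7"] = out["sales"].shift(7)
    # shift(1) before rolling so a day's own sales never leak into its feature
    out["rolling_mean_7"] = out["sales"].shift(1).rolling(7).mean()
    return out

raw = pd.DataFrame({
    "ds": pd.to_datetime(["2026-01-01", "2026-01-01", "2026-01-02", "2026-01-04"]),
    "sales": [3, 2, 4, 6],
})
clean = fill_missing_days(raw)
features = add_basic_lags(clean)
print(features)
```

The two rows for 1 January are summed into one record, and 3 January appears as a zero-sales day. Early rows carry `NaN` lags because there is no earlier history to look back on.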

### 2. Backtesting & Champion Selection
Next, each model type is rigorously tested to find the best one for the job.
- **Time-Series Cross-Validation:** Instead of a single train-test split, we use expanding-window cross-validation (`_generate_cv_folds`). Each fold trains only on data that precedes its validation window, so the error estimate is not flattered by look-ahead leakage.
- **Hyperparameter Tuning:** For each model, `Optuna` searches for the best hyperparameters (e.g., learning rate, tree depth) within a fixed trial budget (`CFG.n_optuna_trials`).
- **Champion Crowning:** The model with the lowest Mean Absolute Error (MAE) across the validation folds is crowned the "champion" for that specific dish.
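
The expanding-window idea can be sketched in plain Python. This is a minimal illustration, not the `_generate_cv_folds` implementation: the function name `expanding_window_folds`, its parameters, and the naive mean baseline are all assumptions chosen to keep the example self-contained.

```python
def expanding_window_folds(n_samples, n_folds, val_size):
    """Yield (train_indices, val_indices) with a training window that grows each fold."""
    first_val_start = n_samples - n_folds * val_size
    if first_val_start <= 0:
        raise ValueError("not enough samples for the requested folds")
    for k in range(n_folds):
        val_start = first_val_start + k * val_size
        train_idx = list(range(0, val_start))
        val_idx = list(range(val_start, val_start + val_size))
        yield train_idx, val_idx

def mean_absolute_error(actual, predicted):
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)

sales = [10, 12, 11, 13, 15, 14, 16, 18, 17, 19, 21, 20]
fold_maes = []
for train_idx, val_idx in expanding_window_folds(len(sales), n_folds=3, val_size=2):
    train = [sales[i] for i in train_idx]
    # Naive baseline standing in for a real model: predict the training mean
    baseline = sum(train) / len(train)
    actual = [sales[i] for i in val_idx]
    fold_maes.append(mean_absolute_error(actual, [baseline] * len(actual)))
    print(f"train=0..{train_idx[-1]}  val={val_idx}")

print("mean MAE across folds:", round(sum(fold_maes) / len(fold_maes), 2))
```

Every validation index is strictly later than every training index in its fold, which is the property that separates this scheme from a shuffled split.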

### 3. The `process_dish` Worker Function
The entire data preparation and backtesting process is encapsulated within the `process_dish` function in `training_logic.py`. This function acts as a self-contained, parallel-ready worker that takes a dish name and returns the champion model, its performance metrics, and its ideal parameters. The parallel execution cell in Step 6 runs this function for all dishes.
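
For orientation, the sketch below shows the shape of a single worker result and how a champion could be picked from per-model MAE values. The keys are inferred from how this notebook reads each result (`dish`, `champion`, `champion_mae`, `mae`, `avg_sales`, `backtest_preds`). The `pick_champion` helper and all numbers are hypothetical, and the real `process_dish` return value may carry additional fields.

```python
def pick_champion(mae_by_model):
    """Return the model name with the lowest MAE, plus that MAE."""
    name = min(mae_by_model, key=mae_by_model.get)
    return name, mae_by_model[name]

# Illustrative numbers only, not real results
mae = {"prophet": 4.8, "catboost": 3.9, "xgboost": 4.1}
champion, champion_mae = pick_champion(mae)

example_result = {
    "dish": "Example Dish",
    "champion": champion,
    "champion_mae": champion_mae,
    "mae": mae,
    "avg_sales": None,       # assumed to be set only for average-only dishes
    "backtest_preds": None,  # assumed to hold per-model backtest series when available
}
print(example_result["champion"], example_result["champion_mae"])
```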

## Step 5: Multi-Day Prediction API (14-Day Rolling Forecast)
The prediction API returns a **list of forecasts**, one entry per day for the next 14 days (`config.forecast_horizon`).

- **Tree models** use a **recursive forecasting loop**: each day's prediction is appended to the history to compute lags for the next day.
- **Prophet** uses its native `make_future_dataframe()` for multi-step prediction.
- **Average-only dishes** return a flat-line forecast at the saved average.
- SHAP explanations use **name-based feature grouping** via `config.feature_groups` for robustness.
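
The recursive loop used for tree models can be illustrated with a toy example. This is a simplified sketch: `recursive_forecast` and `toy_model` are hypothetical stand-ins, and only two lag features are computed, whereas the notebook's tree models use a larger feature set.

```python
def recursive_forecast(history, horizon, predict_one):
    """Forecast `horizon` steps, feeding each prediction back in as history."""
    history = list(history)  # copy so the caller's data is untouched
    forecasts = []
    for _ in range(horizon):
        features = {
            "lag_1": history[-1],
            "lag_7": history[-7] if len(history) >= 7 else 0.0,
        }
        pred = max(0.0, predict_one(features))  # sales cannot be negative
        forecasts.append(pred)
        history.append(pred)  # the next step's lag_1 is this prediction
    return forecasts

# Stand-in for a trained tree model: a fixed blend of the two lags
def toy_model(features):
    return 0.5 * features["lag_1"] + 0.5 * features["lag_7"]

past_sales = [10, 12, 11, 13, 15, 14, 16]
print(recursive_forecast(past_sales, horizon=3, predict_one=toy_model))
```

Because later steps are built on earlier predictions rather than observed sales, errors can compound over the horizon, which is worth keeping in mind when reading the far end of a 14-day forecast.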

In [None]:
# --- MODEL CACHE ---
_model_cache = {}

def _load_cached(filepath):
    if filepath not in _model_cache:
        with open(filepath, 'rb') as f:
            _model_cache[filepath] = pickle.load(f)
    return _model_cache[filepath]

def clear_model_cache():
    _model_cache.clear()


def _compute_lag_features_from_history(sales_history, dt, config):
    """Compute all lag/rolling features for a single forecast date given a sales history array."""
    vals = sales_history
    n = len(vals)

    features = {}
    for lag in [1, 7, 14, 21, 28]:
        features[f'lag_{lag}'] = vals[-lag] if n >= lag else 0.0

    if n >= 8:
        window = vals[-8:-1]  # shifted by 1
        features['rolling_mean_7'] = np.mean(window)
        features['rolling_std_7'] = np.std(window, ddof=1) if len(window) > 1 else 0.0
    else:
        features['rolling_mean_7'] = np.mean(vals) if n > 0 else 0.0
        features['rolling_std_7'] = 0.0

    if n >= 15:
        window = vals[-15:-1]
        features['rolling_mean_14'] = np.mean(window)
        features['rolling_std_14'] = np.std(window, ddof=1) if len(window) > 1 else 0.0
    else:
        features['rolling_mean_14'] = np.mean(vals) if n > 0 else 0.0
        features['rolling_std_14'] = 0.0

    rm7 = features['rolling_mean_7']
    rm14 = features['rolling_mean_14']
    features['trend_ratio'] = rm7 / rm14 if rm14 != 0 else 1.0

    if n >= 2:
        features['expanding_mean'] = np.mean(vals[:-1])
    else:
        features['expanding_mean'] = np.mean(vals) if n > 0 else 0.0

    weekday_vals = [vals[-s] for s in [7, 14, 21, 28] if n >= s]
    features['lag_same_weekday_avg'] = np.mean(weekday_vals) if weekday_vals else 0.0
    features['lag_same_weekday_std'] = np.std(weekday_vals, ddof=1) if len(weekday_vals) > 1 else 0.0

    return features


def _predict_tree_multiday(model_obj, base_future_row, recent_sales_df, config, country_code, dish_mae):
    """
    Recursive 14-day rolling forecast for tree-based models.
    Returns list of {date, qty, lower, upper, explanation} dicts.
    """
    results = []
    sales_history = recent_sales_df['sales'].values.tolist()
    start_date = base_future_row['ds'].iloc[0]

    # Build feature name -> group mapping from config
    feat_to_group = {}
    for group_name, feat_list in config.feature_groups.items():
        for feat in feat_list:
            feat_to_group[feat] = group_name

    # Build the holiday calendar once; it does not change between forecast days
    local_hols = holidays.SG(years=config.holiday_years) if country_code == 'SG' \
        else holidays.CN(years=config.holiday_years)

    for day_offset in range(config.forecast_horizon):
        dt = start_date + pd.Timedelta(days=day_offset)

        is_hol = 1 if dt in local_hols else 0
        temp = estimate_temperature(dt, country_code)

        lag_feats = _compute_lag_features_from_history(sales_history, dt, config)

        row = {
            'day_of_week': dt.dayofweek,
            'month': dt.month,
            'is_public_holiday': is_hol,
            'rain_lunch_vol': base_future_row['rain_lunch_vol'].iloc[0],
            'temperature': temp,
        }
        row.update(lag_feats)

        future_df = pd.DataFrame([row])[config.tree_features]
        pred = float(model_obj.predict(future_df)[0])
        qty = int(max(0, pred))
        pred_lower = int(max(0, pred - dish_mae))
        # Clamp the upper bound too, so it never falls below the zero-floored quantity
        pred_upper = int(max(0, pred + dish_mae))

        # SHAP explanation with name-based grouping
        try:
            ex = shap.TreeExplainer(model_obj)
            sv = ex.shap_values(future_df)[0]
            # expected_value may be a scalar or a one-element array depending on the model
            base_val = float(np.ravel(ex.expected_value)[0])
            group_shap = {}
            for i, feat_name in enumerate(config.tree_features):
                group = feat_to_group.get(feat_name, "Other")
                group_shap[group] = group_shap.get(group, 0.0) + float(sv[i])

            expl = {
                "Trend": round(base_val + group_shap.get("Lags/Trend", 0.0), 1),
                "Seasonality": round(group_shap.get("Seasonality", 0.0), 1),
                "Holiday": round(group_shap.get("Holiday", 0.0), 1),
                "Weather": round(group_shap.get("Weather", 0.0), 1),
            }
        except Exception:
            expl = {"Trend": round(pred, 1), "Seasonality": 0.0,
                    "Holiday": 0.0, "Weather": 0.0}

        results.append({
            "date": dt.strftime('%Y-%m-%d'),
            "qty": qty,
            "lower": pred_lower,
            "upper": pred_upper,
            "explanation": expl
        })

        # Append prediction to history for next iteration
        sales_history.append(max(0, pred))

    return results


def get_prediction(dish, date_str, lat, lon, rain_forecast=0, model='auto', config=CFG):
    """
    Multi-day prediction API.
    Returns a list of dicts (one per forecast day, up to config.forecast_horizon days).
    For average-only dishes, returns flat-line forecast.
    """
    dt = pd.to_datetime(date_str)
    country = get_country_code(lat, lon)
    safe_name = safe_filename(dish)
    dish_mae = 0.0

    # Registry lookup
    try:
        registry = _load_cached(f'{config.model_dir}/champion_registry.pkl')
        dish_info = registry[dish]
        if model == 'auto':
            model = dish_info['model']
        dish_mae = dish_info['all_mae'].get(model, 0.0) if dish_info['all_mae'] else 0.0
    except Exception:
        if model == 'auto':
            model = 'prophet'

    # Average-only dishes
    if model == 'average':
        try:
            avg_sales = _load_cached(f'{config.model_dir}/average_{safe_name}.pkl')
        except Exception:
            avg_sales = 0
        results = []
        for day_offset in range(config.forecast_horizon):
            d = dt + pd.Timedelta(days=day_offset)
            results.append({
                "Dish": dish, "Date": d.strftime('%Y-%m-%d'),
                "Model Used": "AVERAGE",
                "Prediction": avg_sales,
                "Prediction_Lower": avg_sales,
                "Prediction_Upper": avg_sales,
                "Explanation": {"Trend": float(avg_sales), "Seasonality": 0.0,
                                "Holiday": 0.0, "Weather": 0.0}
            })
        return results

    # Build base context row
    local_hols = holidays.SG(years=config.holiday_years) if country == 'SG' \
        else holidays.CN(years=config.holiday_years)
    is_hol = 1 if dt in local_hols else 0
    temp = estimate_temperature(dt, country)

    base_future = pd.DataFrame({
        'ds': [dt],
        'rain_lunch_vol': [rain_forecast],
        'temperature': [temp],
        'is_public_holiday': [is_hol],
        'day_of_week': [dt.dayofweek],
        'month': [dt.month]
    })

    try:
        if model == 'prophet':
            mp = _load_cached(f'{config.model_dir}/prophet_{safe_name}.pkl')
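            # Note: make_future_dataframe extends from the last date in the model's
            # training history, so these dates do not depend on `date_str`.
            future_dates = mp.make_future_dataframe(periods=config.forecast_horizon)
            future_dates = future_dates.tail(config.forecast_horizon).copy()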
            future_dates['rain_lunch_vol'] = rain_forecast
            future_dates['temperature'] = future_dates['ds'].apply(
                lambda d: estimate_temperature(d, country))
            forecast = mp.predict(future_dates)

            results = []
            for _, row in forecast.iterrows():
                yhat = row['yhat']
                qty = int(max(0, yhat))
                trend = row['trend']
                holiday_val = row['holidays'] if 'holidays' in row else 0.0
                weather = row['extra_regressors_additive'] if 'extra_regressors_additive' in row else 0.0
                seasonality = yhat - trend - holiday_val - weather

                results.append({
                    "Dish": dish,
                    "Date": row['ds'].strftime('%Y-%m-%d'),
                    "Model Used": "PROPHET",
                    "Prediction": qty,
                    "Prediction_Lower": int(max(0, yhat - dish_mae)),
                    # Clamp the upper bound so it never falls below the zero-floored quantity
                    "Prediction_Upper": int(max(0, yhat + dish_mae)),
                    "Explanation": {
                        "Trend": round(trend, 1),
                        "Seasonality": round(seasonality, 1),
                        "Holiday": round(holiday_val, 1),
                        "Weather": round(weather, 1)
                    }
                })
            return results

        elif model in ('catboost', 'xgboost'):
            tree_model = _load_cached(f'{config.model_dir}/{model}_{safe_name}.pkl')
            recent = _load_cached(f'{config.model_dir}/recent_sales_{safe_name}.pkl')

            multiday = _predict_tree_multiday(
                tree_model, base_future, recent, config, country, dish_mae)

            results = []
            for entry in multiday:
                results.append({
                    "Dish": dish,
                    "Date": entry['date'],
                    "Model Used": model.upper(),
                    "Prediction": entry['qty'],
                    "Prediction_Lower": entry['lower'],
                    "Prediction_Upper": entry['upper'],
                    "Explanation": entry['explanation']
                })
            return results

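        else:
            # Unknown model name: return an explicit error instead of an implicit None
            return [{"Error": f"Unknown model '{model}' for {dish}"}]

    except Exception as e:
        return [{"Error": f"Model error for {dish}: {str(e)}"}]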
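
The name-based grouping used for the SHAP explanations can be shown in isolation. In this sketch the group names match those read by `_predict_tree_multiday`, but the feature lists inside each group and the contribution values are assumptions for illustration; the real mapping lives in `config.feature_groups`.

```python
def group_contributions(feature_names, contributions, feature_groups):
    """Sum per-feature contributions into named groups; unknown features go to 'Other'."""
    feat_to_group = {
        feat: group
        for group, feats in feature_groups.items()
        for feat in feats
    }
    totals = {}
    for name, value in zip(feature_names, contributions):
        group = feat_to_group.get(name, "Other")
        totals[group] = totals.get(group, 0.0) + value
    return totals

# Hypothetical grouping and contribution values, for illustration only
feature_groups = {
    "Lags/Trend": ["lag_1", "lag_7"],
    "Seasonality": ["day_of_week", "month"],
    "Holiday": ["is_public_holiday"],
    "Weather": ["rain_lunch_vol", "temperature"],
}
names = ["month", "lag_1", "temperature", "lag_7", "is_public_holiday"]
values = [0.5, 2.0, -1.0, 1.5, 0.0]
print(group_contributions(names, values, feature_groups))
```

Looking features up by name rather than by column position means the grouping stays correct if the feature order changes or a new feature is added, which is the robustness the section refers to.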

## Step 6: Parallel Execution & Results
Run all dishes in parallel using `ProcessPoolExecutor`. Each dish is independently processed across CPU cores.
After completion, aggregate results into a leaderboard and save the champion registry.

In [None]:
# --- PARALLEL EXECUTION ---
if __name__ == "__main__":
    # 1. Load and prepare the master DataFrame
    # This function also cleans the data by normalizing dates and aggregating sales.
    raw_df = fetch_training_data()

    # 2. Define global context for this run
    lat_in, lon_in = 31.23, 121.47  # Shanghai
    CFG = PipelineConfig()

    # 3. Enrich the data with location-specific context and resolve the country code
    enriched_df, country = add_local_context(raw_df, lat_in, lon_in)

    # 4. Run the full training pipeline in parallel for each dish
    unique_dishes = enriched_df['dish'].unique()
    results = []

    print(f"\n{'='*95}")
    print(f"STARTING PARALLEL TRAINING FOR {len(unique_dishes)} DISHES "
          f"({CFG.max_workers} workers, {CFG.n_optuna_trials} Optuna trials each)")
    print(f"{'='*95}")

    # Use ProcessPoolExecutor to run the `process_dish` function for each dish on a separate core
    with ProcessPoolExecutor(max_workers=CFG.max_workers) as executor:
        # Create a dictionary mapping future objects to dish names for easy lookup
        futures = {
            executor.submit(process_dish, dish, enriched_df, country, CFG): dish
            for dish in unique_dishes
        }
        # Process results as they are completed
        for future in as_completed(futures):
            dish_name = futures[future]
            try:
                result = future.result()
                results.append(result)
                # Print real-time progress for each dish
                if result['champion'] == 'average':
                    print(f"  {dish_name:<35} | AVG (short data) -> avg_sales={result['avg_sales']}")
                else:
                    mae = result['mae']
                    print(f"  {dish_name:<35} | P={mae['prophet']:<7} C={mae['catboost']:<7} "
                          f"X={mae['xgboost']:<7} -> {result['champion'].upper()}")
            except Exception as e:
                print(f"  {dish_name:<35} | ERROR: {e}")

    # 5. Aggregate and display the final results from all parallel runs
    champion_map = {}
    all_predictions = {}
    results_rows = []

    for r in results:
        dish = r['dish']
        champion_map[dish] = {
            'model': r['champion'],
            'mae': r.get('champion_mae', 0.0),
            'all_mae': r['mae']
        }
        if r['backtest_preds'] is not None:
            all_predictions[dish] = r['backtest_preds']

        if r['champion'] == 'average':
            results_rows.append({
                'Dish': dish, 'Prophet MAE': '-', 'CatBoost MAE': '-',
                'XGBoost MAE': '-', 'Winner': 'AVERAGE'
            })
        else:
            results_rows.append({
                'Dish': dish,
                'Prophet MAE': r['mae']['prophet'],
                'CatBoost MAE': r['mae']['catboost'],
                'XGBoost MAE': r['mae']['xgboost'],
                'Winner': r['champion'].upper()
            })

    with open(f'{CFG.model_dir}/champion_registry.pkl', 'wb') as f:
        pickle.dump(champion_map, f)

    clear_model_cache()

    results_table = pd.DataFrame(results_rows)

    print(f"\n{'='*50}")
    print("MODEL LEADERBOARD (Lower MAE is Better)")
    print(f"{'='*50}")
    display(results_table)
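
To show how the saved registry is consumed later, the sketch below writes a registry with the same entry layout (`model`, `mae`, `all_mae`) to a temporary directory and reads it back. The dish name and numbers are made up, and a temporary directory stands in for `CFG.model_dir`.

```python
import os
import pickle
import tempfile

# Illustrative registry entry; the numbers are made up
registry = {
    "Example Dish": {
        "model": "catboost",
        "mae": 3.9,
        "all_mae": {"prophet": 4.8, "catboost": 3.9, "xgboost": 4.1},
    }
}

with tempfile.TemporaryDirectory() as model_dir:
    path = os.path.join(model_dir, "champion_registry.pkl")
    with open(path, "wb") as f:
        pickle.dump(registry, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

entry = loaded["Example Dish"]
model_name = entry["model"]
# The selected model's backtest MAE is what get_prediction uses as the +/- band
band = entry["all_mae"].get(model_name, 0.0)
print(model_name, band)
```

Pickle files should only be loaded from trusted locations, since unpickling can execute arbitrary code.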

In [None]:
# --- VISUALIZATION A: MAE Comparison Bar Chart ---
# Filter to ML-trained dishes only (exclude average-only)
ml_rows = results_table[results_table['Winner'] != 'AVERAGE'].copy()

if len(ml_rows) > 0:
    fig, ax = plt.subplots(figsize=(16, 6))

    dishes = ml_rows['Dish']
    x = np.arange(len(dishes))
    width = 0.25

    bars_p = ax.bar(x - width, ml_rows['Prophet MAE'].astype(float), width,
                    label='Prophet', color='#4C72B0')
    bars_c = ax.bar(x, ml_rows['CatBoost MAE'].astype(float), width,
                    label='CatBoost', color='#DD8452')
    bars_x = ax.bar(x + width, ml_rows['XGBoost MAE'].astype(float), width,
                    label='XGBoost', color='#55A868')

    for i, (_, row) in enumerate(ml_rows.iterrows()):
        p_mae = float(row['Prophet MAE'])
        c_mae = float(row['CatBoost MAE'])
        x_mae = float(row['XGBoost MAE'])
        winner_mae = min(p_mae, c_mae, x_mae)
        if row['Winner'] == 'PROPHET':
            offset = -width
        elif row['Winner'] == 'CATBOOST':
            offset = 0
        else:
            offset = width
        ax.plot(x[i] + offset, winner_mae, marker='*', color='gold', markersize=14, zorder=5)

    ax.set_xlabel('Dish')
    ax.set_ylabel('MAE (plates)')
    ax.set_title('Model MAE Comparison by Dish (lower is better)')
    ax.set_xticks(x)
    ax.set_xticklabels(dishes, rotation=45, ha='right', fontsize=8)
    ax.legend()
    ax.yaxis.set_minor_locator(ticker.AutoMinorLocator())
    plt.tight_layout()
    plt.show()

# Show average-only dishes if any
avg_rows = results_table[results_table['Winner'] == 'AVERAGE']
if len(avg_rows) > 0:
    print(f"\nDishes using simple average (< {CFG.min_ml_days} days of data):")
    for _, row in avg_rows.iterrows():
        print(f"  - {row['Dish']}")

In [None]:
# --- VISUALIZATION B: Actual vs Predicted (Last Fold, Winning Model) ---
# Only plot dishes that went through ML (have backtest predictions)
ml_dishes = [r for r in results if r['champion'] != 'average' and r['backtest_preds'] is not None]

if len(ml_dishes) > 0:
    n_dishes = len(ml_dishes)
    ncols = 4
    nrows = int(np.ceil(n_dishes / ncols))

    fig, axes = plt.subplots(nrows, ncols, figsize=(18, 4 * nrows), squeeze=False)

    for idx, r in enumerate(ml_dishes):
        ax = axes[idx // ncols][idx % ncols]
        dish = r['dish']
        winner = r['champion']

        preds = r['backtest_preds'].get(winner)
        if preds is not None:
            dates = pd.to_datetime(preds['dates'])
            ax.plot(dates, preds['actual'], label='Actual', color='#333333', linewidth=1.5)
            ax.plot(dates, preds['predicted'], label='Predicted', color='#E24A33',
                    linewidth=1.5, linestyle='--')
            ax.fill_between(dates, preds['actual'], preds['predicted'],
                            alpha=0.15, color='#E24A33')

        ax.set_title(f"{dish}\n({winner.upper()})", fontsize=8, fontweight='bold')
        ax.tick_params(axis='x', rotation=30, labelsize=6)
        ax.tick_params(axis='y', labelsize=7)
        if idx == 0:
            ax.legend(fontsize=7)

    for idx in range(n_dishes, nrows * ncols):
        axes[idx // ncols][idx % ncols].set_visible(False)

    fig.suptitle('Actual vs Predicted Sales (Last CV Fold, Winning Model)', fontsize=13, y=1.01)
    plt.tight_layout()
    plt.show()
else:
    print("No ML-trained dishes to plot backtests for.")

In [None]:
# --- VISUALIZATION C: Multi-Day Forecast (14 days) ---
forecast_date = '2026-05-20'
rain_input = 10.0

all_forecasts = {}
day1_summary = []

for dish_name in enriched_df['dish'].unique():
    preds = get_prediction(
        dish=dish_name, date_str=forecast_date,
        lat=lat_in, lon=lon_in, rain_forecast=rain_input
    )
    if preds and 'Error' not in preds[0]:
        all_forecasts[dish_name] = preds
        p0 = preds[0]
        day1_summary.append({
            'Dish': p0['Dish'],
            'Day 1 Qty': p0['Prediction'],
            'Lower': p0['Prediction_Lower'],
            'Upper': p0['Prediction_Upper'],
            'Model': p0['Model Used'],
            'Trend': p0['Explanation']['Trend'],
            'Seasonality': p0['Explanation']['Seasonality'],
            'Holiday': p0['Explanation']['Holiday'],
            'Weather': p0['Explanation']['Weather']
        })

# Day-1 summary table
day1_df = pd.DataFrame(day1_summary)
print(f"Forecast starting: {forecast_date} | Rain: {rain_input}mm | Horizon: {CFG.forecast_horizon} days")
print(f"{'='*90}")
display(day1_df)

# Line chart: 14-day forecast per dish; the shaded band is prediction +/- backtest MAE,
# not a statistical confidence interval
n_dishes = len(all_forecasts)
if n_dishes > 0:
    ncols = 4
    nrows = int(np.ceil(n_dishes / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(18, 4 * nrows), squeeze=False)

    colors_map = {'PROPHET': '#4C72B0', 'CATBOOST': '#DD8452', 'XGBOOST': '#55A868', 'AVERAGE': '#999999'}

    for idx, (dish_name, preds) in enumerate(all_forecasts.items()):
        ax = axes[idx // ncols][idx % ncols]
        dates = [pd.to_datetime(p['Date']) for p in preds]
        qtys = [p['Prediction'] for p in preds]
        lowers = [p['Prediction_Lower'] for p in preds]
        uppers = [p['Prediction_Upper'] for p in preds]
        model_used = preds[0]['Model Used']
        color = colors_map.get(model_used, '#333333')

        ax.plot(dates, qtys, marker='o', markersize=3, color=color, linewidth=1.5,
                label=model_used)
        ax.fill_between(dates, lowers, uppers, alpha=0.2, color=color)

        ax.set_title(f"{dish_name}\n({model_used})", fontsize=8, fontweight='bold')
        ax.tick_params(axis='x', rotation=30, labelsize=6)
        ax.tick_params(axis='y', labelsize=7)
        ax.legend(fontsize=7)

    for idx in range(n_dishes, nrows * ncols):
        axes[idx // ncols][idx % ncols].set_visible(False)

    fig.suptitle(f'{CFG.forecast_horizon}-Day Rolling Forecast per Dish (with Confidence Bands)',
                 fontsize=13, y=1.01)
    plt.tight_layout()
    plt.show()