# 02 — Forecast plus Drivers
Add a seasonal baseline (trailing 4-week mean per product-region) and driver insights (feature importance or SHAP).

In [None]:

import pandas as pd, numpy as np
from pathlib import Path
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt

# Locate the project root (one level above this file's folder).
# `__file__` only exists when this runs as a script; in a notebook kernel it
# raises NameError, so fall back to the parent of the working directory
# (assumes the kernel's working directory is the notebook's own folder — TODO confirm).
try:
    project_root = Path(__file__).resolve().parents[1]
except NameError:
    project_root = Path.cwd().parent

data_path = project_root / "data" / "sample_transactions.csv"
df = pd.read_csv(data_path, parse_dates=["date"])

# Bucket every transaction into its weekly period and keep the period's start
# timestamp as the week label (vectorised; no per-row Python lambda).
df['week'] = df['date'].dt.to_period('W').dt.start_time

# One row per (week, product, region): total quantity, mean price, mean promo.
agg = df.groupby(['week', 'product', 'region'], as_index=False).agg(
    qty=('qty', 'sum'),
    price=('price', 'mean'),
    promo=('promo', 'mean'),
)

# Baseline prep: put rows in week order and give every product-region pair a
# single series id; the trailing 4-week mean is then taken per id.
agg = agg.sort_values(by='week')
series_id = agg['product'] + '_' + agg['region']
agg['key'] = series_id
def seasonal_baseline(s, window=4):
    """Return the trailing mean of the `window` observations before each row.

    The series is shifted by one position first, so the value at a row never
    includes that row's own observation. Rows that do not yet have `window`
    prior observations come out as NaN.

    Args:
        s: pandas Series of values, ordered oldest to newest.
        window: number of prior observations to average (default 4).

    Returns:
        A Series with the same index as `s`.
    """
    return s.shift(1).rolling(window).mean()
# transform() hands back a result aligned to agg's own index. apply() on
# pandas 2.x (group_keys=True by default) prepends the group key to the index,
# and that MultiIndex result cannot be assigned back as a column.
agg['baseline'] = agg.groupby('key')['qty'].transform(seasonal_baseline)

# Model inputs: price, promo, and the ISO week-of-year number.
iso_calendar = agg['week'].dt.isocalendar()
agg['week_num'] = iso_calendar['week'].astype(int)

# Keep only rows that have a baseline value, so features, target and baseline
# all cover exactly the same rows.
mask = agg['baseline'].notna()
feature_cols = ['price', 'promo', 'week_num']
X = agg.loc[mask, feature_cols].fillna(0)
y = agg.loc[mask, 'qty']
base = agg.loc[mask, 'baseline']

# Positional 80/20 hold-out: the first 80% of rows train the model, the
# remaining rows are used for scoring. (A random forest is not ideal for time
# series, but it is good enough for quick driver signals.)
split = int(0.8 * len(X))
X_train = X.iloc[:split]
X_test = X.iloc[split:]
y_train = y.iloc[:split]
y_test = y.iloc[split:]
base_test = base.iloc[split:]

# Fit the forest on the training rows and predict the held-out rows.
rf = RandomForestRegressor(random_state=42, n_estimators=200)
rf.fit(X_train, y_train)
pred = rf.predict(X_test)

# Mean absolute error of the baseline and of the model on the same held-out rows.
mae_base = mean_absolute_error(y_test, base_test)
mae_model = mean_absolute_error(y_test, pred)
scores = {"MAE_baseline": mae_base, "MAE_model": mae_model}
print(scores)

# Rank the fitted forest's feature importances, largest first.
imp = pd.Series(rf.feature_importances_, index=X.columns)
imp = imp.sort_values(ascending=False)
print("Feature importance:\n", imp)

# Make sure the output folder exists (relative to the working directory),
# then draw the ranking as a bar chart and save it there.
Path("figures").mkdir(exist_ok=True)
imp.plot.bar()
plt.title("Feature Importance")
plt.tight_layout()
plt.savefig("figures/feature_importance.png")
print("Saved figures/feature_importance.png")
