# 01 — SARIMAX con exógenas (por Store)

**Objetivo:** forecasting de `Weekly_Sales` semanal por `Store` usando SARIMAX con variables exógenas.

## Supuesto experimental (oracle exog)
Se asume disponibilidad de todas las covariables exógenas durante el horizonte de predicción (escenario oracle).

## Outputs estándar
- `outputs/predictions/sarimax_exog_predictions.csv` con: `Store, Date, y_true, y_pred, model`
- `outputs/metrics/sarimax_exog_metrics_global.csv`
- `outputs/metrics/sarimax_exog_metrics_by_store.csv`
- `outputs/figures/sarimax_exog_plot_*.png`

In [None]:
# 0) Imports y configuración
from __future__ import annotations

import json
import sys
from pathlib import Path

import numpy as np
import pandas as pd

PROJECT_ROOT = Path.cwd().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.common import (
    SplitConfig,
    TEST_WEEKS,
    compute_metrics,
    load_data,
    make_features,
    save_outputs,
    temporal_split,
 )

MODEL_NAME = 'sarimax_exog'
SEED = 42
np.random.seed(SEED)

DATA_PATH = PROJECT_ROOT / 'data' / 'Walmart_Sales.csv'
METADATA_PATH = PROJECT_ROOT / 'outputs' / 'metadata.json'
OUTPUTS_DIR = PROJECT_ROOT / 'outputs'

## 1) Cargar metadata (split + features)
Esto garantiza consistencia entre modelos.

In [None]:
if METADATA_PATH.exists():
    metadata = json.loads(METADATA_PATH.read_text(encoding='utf-8'))
    split = metadata['split']
    feature_cols = metadata['features']
    split_cfg = SplitConfig(
        train_start=pd.Timestamp(split['train_start']),
        train_end=pd.Timestamp(split['train_end']),
        val_start=pd.Timestamp(split['val_start']),
        val_end=pd.Timestamp(split['val_end']),
        test_start=pd.Timestamp(split['test_start']),
        test_end=pd.Timestamp(split['test_end']),
    )
    print('Split (metadata):', split)
    print('N features:', len(feature_cols))
else:
    metadata = None
    split = None
    feature_cols = None
    split_cfg = None
    print('metadata.json not found; will compute split with TEST_WEEKS=', TEST_WEEKS)

Split: {'train_start': '2010-02-05', 'train_end': '2012-07-06', 'val_start': '2012-07-13', 'val_end': '2012-08-31', 'test_start': '2012-09-07', 'test_end': '2012-10-26'}
N features: 16


## 2) Carga de datos + features
- Parseo/orden
- Construcción de lags/rolling (sin leakage)
- Exógenas alineadas por fecha

In [None]:
df = load_data(DATA_PATH)
df_feat, feature_cols_auto = make_features(df, add_calendar=True)
if feature_cols is None:
    feature_cols = feature_cols_auto

# Importante: para entrenar, debes decidir cómo tratar NaNs creados por lags/rolling
# Opción típica: descartar filas con NaNs en features (por store al inicio)
model_df = df_feat.dropna(subset=feature_cols + ['Weekly_Sales']).copy()
model_df.shape

(4095, 19)

## 3) Split temporal
Reutiliza exactamente el split definido en el notebook 00.

In [None]:
if split_cfg is None:
    train_df, val_df, test_df, split_cfg = temporal_split(df, test_weeks=TEST_WEEKS)
else:
    train_df = df[df['Date'].between(split_cfg.train_start, split_cfg.train_end)].copy()
    val_df = df[df['Date'].between(split_cfg.val_start, split_cfg.val_end)].copy()
    test_df = df[df['Date'].between(split_cfg.test_start, split_cfg.test_end)].copy()

# Aplicar el split sobre model_df (ya sin NaNs por lags)
train = model_df[model_df['Date'].between(split_cfg.train_start, split_cfg.train_end)].copy()
val = model_df[model_df['Date'].between(split_cfg.val_start, split_cfg.val_end)].copy()
test = model_df[model_df['Date'].between(split_cfg.test_start, split_cfg.test_end)].copy()

print(len(train), len(val), len(test))

3375 360 360


## 4) Entrenamiento del modelo
Implementación SARIMAX por tienda con exógenas.

In [5]:
from statsmodels.tsa.statespace.sarimax import SARIMAX
from warnings import filterwarnings

filterwarnings("ignore")

# SARIMAX ya modela dependencias temporales; usamos exógenas + calendario (sin lags/rolling)
sarimax_exog_cols = [c for c in feature_cols if not c.startswith("lag_") and not c.startswith("roll_")]

preds = pd.Series(index=test.index, dtype=float)
failed_stores = []

for store, g_train in train.groupby("Store"):
    g_test = test[test["Store"] == store]
    if g_test.empty or g_train.empty:
        continue

    y_train = g_train["Weekly_Sales"].astype(float)
    X_train = g_train[sarimax_exog_cols].astype(float)
    X_test = g_test[sarimax_exog_cols].astype(float)

    try:
        model = SARIMAX(
            y_train,
            exog=X_train,
            order=(1, 1, 1),
            seasonal_order=(0, 0, 0, 0),
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        res = model.fit(disp=False)
        forecast = res.get_forecast(steps=len(g_test), exog=X_test)
        y_pred_store = forecast.predicted_mean
        preds.loc[g_test.index] = y_pred_store.values
    except Exception:
        failed_stores.append(int(store))
        preds.loc[g_test.index] = y_train.mean()

# Relleno de seguridad si alguna predicción quedó NaN
if preds.isna().any():
    preds = preds.fillna(train["Weekly_Sales"].mean())

y_pred_test = preds.values
print("Failed stores:", len(failed_stores))

Failed stores: 0


## 5) Métricas (MAE, RMSE, sMAPE)
Se reporta:
- Global
- Por store

In [9]:
pred_df = pd.DataFrame({
    'Store': test['Store'].astype(int).values,
    'Date': test['Date'].values,
    'y_true': test['Weekly_Sales'].values,
    'y_pred': np.asarray(y_pred_test, dtype=float),
    'model': MODEL_NAME,
})

global_metrics = compute_metrics(pred_df['y_true'].values, pred_df['y_pred'].values)
metrics_global_df = pd.DataFrame([{'model': MODEL_NAME, **global_metrics}])

by_store = []
for store, g in pred_df.groupby('Store'):
    m = compute_metrics(g['y_true'].values, g['y_pred'].values)
    by_store.append({'model': MODEL_NAME, 'Store': int(store), **m})
metrics_by_store_df = pd.DataFrame(by_store).sort_values('Store')

metrics_global_df, metrics_by_store_df.head()

(          model            MAE           RMSE      sMAPE
 0  sarimax_exog  124157.680285  169022.469951  10.889343,
           model  Store            MAE           RMSE     sMAPE
 0  sarimax_exog      1  122537.806903  130906.969556  7.594821
 1  sarimax_exog      2  145186.688393  153599.094294  7.509888
 2  sarimax_exog      3   30852.759826   33342.089262  7.611823
 3  sarimax_exog      4  143754.257288  154725.388471  6.557401
 4  sarimax_exog      5   10757.029110   13023.135273  3.271099)

## 6) Guardado de outputs estándar

In [10]:
paths = save_outputs(
    model_name=MODEL_NAME,
    predictions=pred_df,
    metrics_global=metrics_global_df,
    metrics_by_store=metrics_by_store_df,
    output_dir=OUTPUTS_DIR,
)
paths

{'predictions': 'c:\\Users\\usuario\\Documents\\Master AI\\TFM\\MEMORIA 2.0\\outputs\\predictions\\sarimax_exog_predictions.csv',
 'metrics_global': 'c:\\Users\\usuario\\Documents\\Master AI\\TFM\\MEMORIA 2.0\\outputs\\metrics\\sarimax_exog_metrics_global.csv',
 'metrics_by_store': 'c:\\Users\\usuario\\Documents\\Master AI\\TFM\\MEMORIA 2.0\\outputs\\metrics\\sarimax_exog_metrics_by_store.csv'}

## 7) Figuras
- 3 tiendas: real vs predicción en test
- Distribución del error (`y_true - y_pred`)

Guardar PNGs en `outputs/figures/`.

In [8]:
import matplotlib.pyplot as plt
import seaborn as sns

FIG_DIR = OUTPUTS_DIR / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Selección de 3 tiendas (mayor media de ventas en test)
top_stores = (
    pred_df.groupby("Store")["y_true"]
    .mean()
    .sort_values(ascending=False)
    .head(3)
    .index
    .tolist()
)

for store in top_stores:
    g = pred_df[pred_df["Store"] == store].sort_values("Date")
    plt.figure(figsize=(10, 4))
    plt.plot(g["Date"], g["y_true"], label="y_true")
    plt.plot(g["Date"], g["y_pred"], label="y_pred")
    plt.title(f"Store {store} — SARIMAX")
    plt.xlabel("Date")
    plt.ylabel("Weekly_Sales")
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIG_DIR / f"{MODEL_NAME}_plot_store_{store}.png", dpi=150)
    plt.close()

# Distribución de error
errors = pred_df["y_true"] - pred_df["y_pred"]
plt.figure(figsize=(8, 4))
sns.histplot(errors, bins=30, kde=True)
plt.title("Error distribution (y_true - y_pred)")
plt.xlabel("Error")
plt.tight_layout()
plt.savefig(FIG_DIR / f"{MODEL_NAME}_plot_error_dist.png", dpi=150)
plt.close()

Matplotlib is building the font cache; this may take a moment.
