In [None]:
# ====================================
# 1) Imports
# ====================================
import itertools
import json
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yfinance as yf
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.statespace.sarimax import SARIMAX

warnings.filterwarnings("ignore")

# ====================================
# 2) Utils
# ====================================
def ensure_datetime_freq(series: pd.Series, use_bdays=True) -> pd.Series:
    """Return a sorted copy of ``series`` on a DatetimeIndex with a frequency.

    Steps:
      1. Convert a non-datetime index with ``pd.to_datetime(errors="coerce")``
         and drop rows whose timestamp could not be parsed (NaT).
      2. Drop duplicated timestamps, keeping the last occurrence, then sort.
      3. If pandas can infer a regular frequency, attach it to the index.
         Otherwise reindex onto a regular business-day (``use_bdays=True``)
         or calendar-day range between the first and last timestamp and
         forward-fill the gaps. With business days, rows falling on weekends
         are dropped by the reindex.

    Series with fewer than 3 timestamps are returned after steps 1-2 only,
    because ``pd.infer_freq`` needs at least 3 dates.

    Parameters
    ----------
    series : pd.Series
        Input time series (a one-column DataFrame is handled the same way).
    use_bdays : bool, default True
        Regularize onto business days instead of calendar days.

    Returns
    -------
    pd.Series
        A new object; the input is not modified.
    """
    s = series.copy()
    if not isinstance(s.index, pd.DatetimeIndex):
        s.index = pd.to_datetime(s.index, errors="coerce")
    s = s[~s.index.isna()]
    # reindex() refuses an axis with duplicate labels; keep the latest value.
    s = s[~s.index.duplicated(keep="last")].sort_index()

    if len(s) < 3:
        # Too short to infer or impose a frequency.
        return s

    inferred = pd.infer_freq(s.index)
    if inferred is None:
        full_idx = (
            pd.bdate_range(s.index.min(), s.index.max())
            if use_bdays
            else pd.date_range(s.index.min(), s.index.max())
        )
        s = s.reindex(full_idx).ffill()
    else:
        # Attach the inferred frequency so downstream models see it explicitly.
        s.index = pd.DatetimeIndex(s.index, freq=inferred)
    return s

# Grid search ARIMA
def grid_search_arima(series, p_range=(0,4), d_range=(0,3), q_range=(0,4)):
    """Fit every ARIMA(p, d, q) in the grid and return the lowest-AIC order.

    Parameters
    ----------
    series
        Endogenous series passed to statsmodels ``ARIMA``.
    p_range, d_range, q_range : tuple
        ``(start, stop)`` pairs unpacked into ``range``; ``stop`` is exclusive
        (the defaults search p, q in 0..3 and d in 0..2).

    Returns
    -------
    (best_order, best_aic)
        ``best_order`` is a ``(p, d, q)`` tuple, or ``None`` with
        ``best_aic == inf`` when no candidate produced a finite AIC.

    Notes
    -----
    AIC values of models with different ``d`` are not strictly comparable, so
    treat cross-``d`` selection as a heuristic.
    """
    best_aic, best_order = np.inf, None
    for order in itertools.product(range(*p_range), range(*d_range), range(*q_range)):
        try:
            aic = ARIMA(series, order=order).fit().aic
        except Exception:
            # Skip orders that fail to estimate. Catching Exception (not a bare
            # except) lets KeyboardInterrupt stop a long search.
            continue
        # NaN/inf AICs come from degenerate fits and must never be selected.
        if np.isfinite(aic) and aic < best_aic:
            best_aic, best_order = aic, order
    return best_order, best_aic

# Grid search SARIMA
def grid_search_sarima(series, m=20,
                       p_range=(0,3), d_range=(0,2), q_range=(0,3),
                       P_range=(0,2), D_range=(0,2), Q_range=(0,2)):
    """Grid-search SARIMA(p,d,q)x(P,D,Q,m) models and return the lowest-AIC one.

    Every ``*_range`` is a ``(start, stop)`` pair unpacked into ``range``
    (``stop`` exclusive); ``m`` is the seasonal period in observations.
    Models are fit with stationarity/invertibility constraints disabled.

    With the defaults this fits 3*2*3 * 2*2*2 = 144 models, each with a
    seasonal period of ``m``, so it can be very slow on long daily series.

    Returns
    -------
    (best_order, best_seasonal, best_aic)
        ``best_order`` is ``(p, d, q)`` and ``best_seasonal`` is ``(P, D, Q, m)``.
        Both are ``None`` and ``best_aic`` is ``inf`` when no candidate
        produced a finite AIC.
    """
    best_aic, best_order, best_seasonal = np.inf, None, None
    # The seasonal grid does not depend on (p, d, q): build it once.
    seasonal_grid = [
        (P, D, Q, m)
        for P, D, Q in itertools.product(range(*P_range), range(*D_range), range(*Q_range))
    ]
    for order in itertools.product(range(*p_range), range(*d_range), range(*q_range)):
        for seasonal in seasonal_grid:
            try:
                aic = SARIMAX(series, order=order, seasonal_order=seasonal,
                              enforce_stationarity=False,
                              enforce_invertibility=False).fit(disp=False).aic
            except Exception:
                # Skip candidates that fail to estimate; KeyboardInterrupt
                # still propagates so the search can be stopped.
                continue
            # Unconstrained fits can yield NaN/inf AICs; never select those.
            if np.isfinite(aic) and aic < best_aic:
                best_aic, best_order, best_seasonal = aic, order, seasonal
    return best_order, best_seasonal, best_aic

# Grid search SARIMAX
def grid_search_sarimax(series, exog, m=20,
                        p_range=(0,3), d_range=(0,2), q_range=(0,3),
                        P_range=(0,2), D_range=(0,2), Q_range=(0,2)):
    """Grid-search SARIMAX(p,d,q)x(P,D,Q,m) models with ``exog`` regressors.

    Identical to the SARIMA search except that ``exog`` is passed to every
    model. ``exog`` must be aligned with ``series``. If it contains NaNs, every
    candidate fails and the function returns ``(None, None, inf)``.

    Every ``*_range`` is a ``(start, stop)`` pair unpacked into ``range``
    (``stop`` exclusive); ``m`` is the seasonal period in observations.

    Returns
    -------
    (best_order, best_seasonal, best_aic)
        ``best_order`` is ``(p, d, q)`` and ``best_seasonal`` is ``(P, D, Q, m)``.
        Both are ``None`` and ``best_aic`` is ``inf`` when no candidate
        produced a finite AIC.
    """
    best_aic, best_order, best_seasonal = np.inf, None, None
    # The seasonal grid does not depend on (p, d, q): build it once.
    seasonal_grid = [
        (P, D, Q, m)
        for P, D, Q in itertools.product(range(*P_range), range(*D_range), range(*Q_range))
    ]
    for order in itertools.product(range(*p_range), range(*d_range), range(*q_range)):
        for seasonal in seasonal_grid:
            try:
                aic = SARIMAX(series, order=order, seasonal_order=seasonal,
                              exog=exog,
                              enforce_stationarity=False,
                              enforce_invertibility=False).fit(disp=False).aic
            except Exception:
                # Skip candidates that fail to estimate; KeyboardInterrupt
                # still propagates so the search can be stopped.
                continue
            # Unconstrained fits can yield NaN/inf AICs; never select those.
            if np.isfinite(aic) and aic < best_aic:
                best_aic, best_order, best_seasonal = aic, order, seasonal
    return best_order, best_seasonal, best_aic


In [None]:
# ====================================
# 3) Run for selected tickers
# ====================================
TICKERS = ["AEP", "DUK", "SO", "ED", "EXC"]
PERIOD = "5y"
EXOG_TICKER = "CL=F"  # exogenous regressor example: oil futures
SEASONAL_M = 20       # seasonal period in observations; presumably ~one trading month


def _download_close(symbol: str) -> pd.Series:
    """Download the adjusted close for one symbol as a Series (NaNs dropped)."""
    close = yf.download(symbol, period=PERIOD, progress=False, auto_adjust=True)["Close"]
    # Recent yfinance versions can return MultiIndex columns even for a single
    # symbol, making ["Close"] a one-column DataFrame; reduce it to a Series.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    return close.dropna()


# The regressor is the same for every ticker: download it once, outside the loop.
exog_raw = ensure_datetime_freq(_download_close(EXOG_TICKER))

results = []

for ticker in TICKERS:
    print(f"\n=== {ticker} ===")
    data = _download_close(ticker)
    if data.empty:
        print(f"No data returned for {ticker}; skipping.")
        continue
    data = ensure_datetime_freq(data)

    # Align the regressor to the target's dates. bfill covers dates before the
    # regressor's first observation; SARIMAX rejects NaN exog, and otherwise
    # every SARIMAX candidate would fail silently.
    exog = exog_raw.reindex(data.index).ffill().bfill()

    # ARIMA
    arima_order, arima_aic = grid_search_arima(data)
    print(f"ARIMA best {arima_order} AIC={arima_aic:.2f}")

    # SARIMA
    sarima_order, sarima_seasonal, sarima_aic = grid_search_sarima(data, m=SEASONAL_M)
    print(f"SARIMA best {sarima_order} x {sarima_seasonal} AIC={sarima_aic:.2f}")

    # SARIMAX
    sarimax_order, sarimax_seasonal, sarimax_aic = grid_search_sarimax(data, exog, m=SEASONAL_M)
    print(f"SARIMAX best {sarimax_order} x {sarimax_seasonal} AIC={sarimax_aic:.2f}")

    results.append({
        "symbol": ticker,
        "ARIMA order": arima_order, "ARIMA AIC": arima_aic,
        "SARIMA order": sarima_order, "SARIMA seasonal": sarima_seasonal, "SARIMA AIC": sarima_aic,
        "SARIMAX order": sarimax_order, "SARIMAX seasonal": sarimax_seasonal, "SARIMAX AIC": sarimax_aic
    })



=== AEP ===
ARIMA best (0, 1, 0) AIC=3769.98


In [None]:
# ====================================
# 4) Show results as DataFrame
# ====================================
# One row per processed ticker: the best orders and AICs of each model family.
df_results = pd.DataFrame.from_records(results)
df_results


In [None]:
# ====================================
# 5) Plot example forecasts
# ====================================
FORECAST_STEPS = 30  # forecast horizon, in observations of the series' frequency

# One panel per row actually present in df_results (tickers may have been skipped).
n_panels = max(len(df_results), 1)
# squeeze=False always returns a 2-D array, so one panel needs no special case.
fig, axes = plt.subplots(n_panels, 1, figsize=(12, 4 * n_panels), squeeze=False)

for ax, row in zip(axes[:, 0], df_results.to_dict("records")):
    ticker = row["symbol"]
    order = row["ARIMA order"]

    close = yf.download(ticker, period=PERIOD, progress=False, auto_adjust=True)["Close"]
    # Recent yfinance versions may return a one-column DataFrame here.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    data = ensure_datetime_freq(close.dropna())

    ax.plot(data.index, data, label="Historical")
    if isinstance(order, tuple):
        forecast = ARIMA(data, order=order).fit().forecast(steps=FORECAST_STEPS)
        ax.plot(forecast.index, forecast, label=f"ARIMA{order}", color="red")
        ax.set_title(f"{ticker} Forecasts")
    else:
        # The grid search found no usable order (every fit failed).
        ax.set_title(f"{ticker}: no ARIMA order available")
    ax.set(xlabel="Date", ylabel="Adjusted close")
    ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# ====================================
# 6) Export MODEL_PARAMS dict
# ====================================
PARAMS_PATH = "best_model_params.json"

# symbol -> best orders per model family (None where the search found nothing).
MODEL_PARAMS = {
    row["symbol"]: {
        "arima": {"order": row["ARIMA order"]},
        "sarima": {"order": row["SARIMA order"], "seasonal_order": row["SARIMA seasonal"]},
        "sarimax": {"order": row["SARIMAX order"], "seasonal_order": row["SARIMAX seasonal"]},
    }
    for row in df_results.to_dict("records")
}

# JSON has no tuple type: orders are written as lists and None as null, so
# convert back with tuple(...) when reloading.
with open(PARAMS_PATH, "w", encoding="utf-8") as f:
    json.dump(MODEL_PARAMS, f, indent=2)

MODEL_PARAMS
