In [None]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
from arch import arch_model

%matplotlib widget
import matplotlib.pyplot as plt

from itables import init_notebook_mode
init_notebook_mode(all_interactive=True)


class ARIMA_GARCH:
    """Two-stage model: a SARIMAX mean model plus a GARCH model on its residuals.

    Intended call order is ``fit_arima()`` -> ``fit_garch()`` -> ``predict()`` /
    ``plot_results()``.
    """

    def __init__(self, ts, exog=None):
        """Store the series to model and optional exogenous regressors.

        Args:
            ts: Series to model. ``predict`` builds its forecast index by adding
                business days to ``ts.index[-1]``, so a date-like index is assumed.
            exog: Optional exogenous regressors passed to SARIMAX.
        """
        self.data = ts
        self.exog = exog
        self.arima_model = None      # fitted SARIMAX results (set by fit_arima)
        self.garch_model = None      # fitted arch results (set by fit_garch)
        self.arima_residuals = None  # SARIMAX residuals used as the GARCH input

    def fit_arima(self, order=(1, 1, 1), seasonal=(0, 0, 0, 0), trend='c'):
        """Fit a SARIMAX mean model and store its residuals for the GARCH stage.

        Args:
            order: Non-seasonal ``(p, d, q)`` order.
            seasonal: Seasonal ``(P, D, Q, s)`` order.
            trend: SARIMAX trend specification.

        Any previously fitted GARCH model is discarded, since it was estimated
        on the residuals of the old mean model.
        """
        model = sm.tsa.statespace.SARIMAX(
            self.data, self.exog, order=order, seasonal_order=seasonal, trend=trend)
        self.arima_model = model.fit()
        # With differencing, the first d + D*s one-step residuals are produced
        # from an uninformed initial state and are roughly the raw series level
        # (e.g. the first price itself). Left in, they dominate the GARCH
        # variance estimate, so drop them before the volatility stage.
        burn = order[1] + seasonal[1] * seasonal[3]
        self.arima_residuals = self.arima_model.resid[burn:]
        self.garch_model = None

    def fit_garch(self, p=1, q=1, mean='Constant', dist="StudentsT"):
        """Fit a GARCH(p, q) volatility model on the stored ARIMA residuals.

        Args:
            p: GARCH lag order of the squared residuals.
            q: GARCH lag order of the conditional variance.
            mean: ``arch_model`` mean specification. The default matches the
                original behaviour; ``'Zero'`` is a common alternative because
                the residuals already have the ARIMA mean removed.
            dist: ``arch_model`` error distribution.

        Raises:
            ValueError: If ``fit_arima`` has not been called.
        """
        if self.arima_residuals is None:
            raise ValueError("Fit ARIMA first before GARCH.")

        model = arch_model(self.arima_residuals, mean=mean, vol='Garch', p=p, q=q, dist=dist)
        self.garch_model = model.fit(disp='off')

    def predict(self, steps=1, exog=None):
        """Forecast the series (ARIMA) and its volatility (GARCH).

        Args:
            steps: Forecast horizon, in business days.
            exog: Out-of-sample exogenous values for the horizon; required by
                SARIMAX when the model was fitted with ``exog``.

        Returns:
            ``(mean, volatility)``: two Series indexed by the ``steps`` business
            days after the last observation. ``volatility`` is the forecast
            conditional standard deviation of the ARIMA residuals.

        Raises:
            ValueError: If either model has not been fitted.
        """
        if self.arima_model is None or self.garch_model is None:
            raise ValueError("Fit both ARIMA and GARCH before predicting.")

        mean_forecast = np.asarray(self.arima_model.forecast(steps=steps, exog=exog))
        # Volatility comes from the *variance* forecast (``.mean`` would be the
        # GARCH mean forecast). The last row holds the forecast made from the
        # end of the sample; older arch versions return one (mostly NaN) row
        # per observation, so select it explicitly.
        variance = self.garch_model.forecast(horizon=steps).variance.iloc[-1].to_numpy()
        vol_forecast = np.sqrt(variance)
        forecast_index = pd.bdate_range(
            start=self.data.index[-1] + pd.tseries.offsets.BusinessDay(1), periods=steps)

        return (pd.Series(mean_forecast, index=forecast_index),
                pd.Series(vol_forecast, index=forecast_index))

    def plot_results(self, steps=10):
        """Plot the observed series with in-sample fits for the last ``steps`` points.

        Draws the SARIMAX in-sample predictions for the final ``steps``
        observations and a +/-1 conditional-standard-deviation band taken from
        the fitted GARCH model. This uses in-sample fitted values, not
        ``predict``.

        Raises:
            ValueError: If either model has not been fitted.
        """
        if self.arima_model is None or self.garch_model is None:
            raise ValueError("Fit both ARIMA and GARCH before plotting.")

        vol_all = np.asarray(self.garch_model.conditional_volatility)
        # Cannot show more points than there are fitted volatilities.
        steps = min(steps, len(vol_all))
        start = len(self.data) - steps
        arima_predicted = self.arima_model.predict(start=start, end=len(self.data) - 1)
        # Align positionally: both arrays cover the final ``steps`` observations.
        mean = np.asarray(arima_predicted)
        vol = vol_all[-steps:]

        plt.figure(figsize=(10, 5))
        plt.plot(self.data, label='Actual Data', color='black')
        plt.plot(arima_predicted, label='ARIMA Predicted', linestyle='dashed', color='blue')
        plt.fill_between(arima_predicted.index,
                         mean - vol,
                         mean + vol,
                         color='gray', alpha=0.3, label='Volatility (±1 std)')

        plt.legend()
        plt.title(f"ARIMA+GARCH Predictions (Last {steps} Steps)")
        plt.show()

In [None]:
# Load the price history in chronological order (the model treats the last
# row as the most recent observation) and keep the final 100 rows.
dat = pd.read_csv("SP.csv", index_col="Date", parse_dates=True).sort_index().iloc[-100:]

In [None]:
# Display the most recent 50 rows of the loaded data.
dat.tail(n=50)

In [None]:
# Example usage: fit the mean model, fit GARCH on its residuals, then request
# the (ARIMA, GARCH) forecasts for the next five business days.
model = ARIMA_GARCH(ts=dat['Adj Close'])
model.fit_arima()
model.fit_garch()
arima, garch = model.predict(steps=5)

In [None]:
# Plot the data together with the in-sample fits for the last 20 observations.
model.plot_results(20)