# TSA Chapter 4: SARIMA Forecast

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/QuantLet/TSA/blob/main/TSA_ch4/TSA_ch4_sarima_forecast/TSA_ch4_sarima_forecast.ipynb)

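This notebook fits a SARIMA$(0,1,1)\times(0,1,1)_{12}$ model to the log of the monthly airline passengers series, holds out the last 24 months, and plots the 24-month forecast against the actual values together with a 95% confidence interval.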

In [None]:
!pip install numpy pandas matplotlib statsmodels -q

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.stattools import acf, pacf
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Shared colour palette (hex codes) used by the figures in this notebook.
COLORS = {
    'blue': '#1A3A6E',
    'red': '#DC3545',
    'green': '#2E7D32',
    'orange': '#E67E22',
    'gray': '#666666',
    'purple': '#8E44AD',
}
# One module-level shortcut per palette entry, looked up by explicit key.
BLUE, RED, GREEN, ORANGE, GRAY, PURPLE = (
    COLORS[key] for key in ('blue', 'red', 'green', 'orange', 'gray', 'purple')
)

# Global matplotlib defaults applied to every figure created afterwards.
plt.rcParams.update({
    # Transparent backgrounds, both on screen and in saved files.
    'figure.facecolor': 'none',
    'axes.facecolor': 'none',
    'savefig.facecolor': 'none',
    'savefig.transparent': True,
    # Minimal frame: no top/right spines and no grid.
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': False,
    # Font sizes.
    'font.size': 10,
    'axes.titlesize': 12,
    'axes.labelsize': 10,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    # Resolution and line weights.
    'figure.dpi': 150,
    'lines.linewidth': 1.2,
    'axes.linewidth': 0.6,
    # Frameless, fully transparent legend box.
    'legend.facecolor': 'none',
    'legend.framealpha': 0,
    'legend.edgecolor': 'none',
})

def save_chart(fig, name):
    """Save ``fig`` as ``<name>.pdf`` and ``<name>.png`` and report it.

    Both files are written through ``fig.savefig`` with a tight bounding
    box, a transparent background and 150 dpi (PDF first, then PNG).
    Afterwards a ``Saved: <name>`` line is printed.

    Args:
        fig: Object exposing a matplotlib-style ``savefig`` method.
        name: Output path without extension.
    """
    for extension in ('pdf', 'png'):
        fig.savefig(f'{name}.{extension}', bbox_inches='tight',
                    transparent=True, dpi=150)
    print(f'Saved: {name}')

In [None]:
# Download the airline passengers CSV; the 'Month' column is parsed as dates
# and used as the index.
data = pd.read_csv(
    'https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv',
    index_col='Month',
    parse_dates=['Month'],
)

# Passenger counts as a Series, with the index tagged as month-start ('MS').
passengers = data['Passengers']
passengers.index.freq = 'MS'

# Natural log of the series; the model below is fitted on this scale.
log_passengers = np.log(passengers)

In [None]:
# Hold out the final 24 months as the test set; everything before is training.
train, test = log_passengers[:-24], log_passengers[-24:]

# SARIMA(0,1,1)(0,1,1) with a 12-period season, fitted on the log series.
# Stationarity and invertibility constraints are switched off for the fit.
model_f = SARIMAX(
    train,
    order=(0, 1, 1),
    seasonal_order=(0, 1, 1, 12),
    enforce_stationarity=False,
    enforce_invertibility=False,
)
fit_f = model_f.fit(disp=False)

# Forecast one step per held-out observation; mean and bounds are in logs.
forecast = fit_f.get_forecast(steps=len(test))
forecast_mean = forecast.predicted_mean
conf_int = forecast.conf_int()

fig, ax = plt.subplots(figsize=(12, 5))

# Only the last 48 training months are drawn, to keep the plot readable.
train_plot = train[-48:]

# Every series is exponentiated to return from logs to the original scale.
ax.plot(
    train_plot.index, np.exp(train_plot.values),
    color=BLUE, linewidth=1.5, label='Training Data',
)
ax.plot(
    test.index, np.exp(test.values),
    color=GREEN, linewidth=2, label='Actual', marker='o', markersize=4,
)
ax.plot(
    test.index, np.exp(forecast_mean.values),
    color=RED, linewidth=2, linestyle='--', label='SARIMA Forecast',
)
# Shaded band between the first and second columns of the interval frame.
ax.fill_between(
    test.index,
    np.exp(conf_int.iloc[:, 0].values),
    np.exp(conf_int.iloc[:, 1].values),
    color=RED, alpha=0.15, label='95% CI',
)
# Dotted vertical line at the last training observation.
ax.axvline(x=train.index[-1], color=GRAY, linestyle=':', alpha=0.7, linewidth=2)

ax.set_title(
    'Airline Passengers: SARIMA$(0,1,1)\\times(0,1,1)_{12}$ Forecast',
    fontweight='bold',
)
ax.set_xlabel('Date')
ax.set_ylabel('Passengers (thousands)')
# Legend sits below the axes in a single row of four entries.
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=4, frameon=False)

plt.tight_layout()
save_chart(fig, 'ch4_sarima_forecast')
plt.show()