# Day 15: SARIMA Models - Seasonal ARIMA

## Extending ARIMA with Seasonal Components

Implement SARIMA(p,d,q)(P,D,Q)s models to capture seasonal patterns in time series data with yearly or periodic cycles.

## 1. Import Libraries and Load Data

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.subplots as sp
import warnings
warnings.filterwarnings('ignore')

# Time Series
from statsmodels.tsa.stattools import adfuller
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Display settings
pd.set_option('display.max_columns', None)
plt.style.use('seaborn-v0_8-darkgrid')

print("✓ Libraries imported successfully")

✓ Libraries imported successfully


In [2]:
# Load daily gold prices
try:
    df = pd.read_csv('../data/gold_prices.csv', parse_dates=['Date'])
except:
    print("⚠ Data file not found")
    df = None

if df is not None:
    # Ensure formatting
    if 'Price' not in df.columns:
        df = df.rename(columns={'Adj Close': 'Price'})
    df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True)

    print(f"Daily data shape: {df.shape}")
    print(f"Date range: {df['Date'].min()} to {df['Date'].max()}")
    
    # Aggregate to monthly
    df['YearMonth'] = df['Date'].dt.to_period('M')
    df_monthly = df.groupby('YearMonth')['Price'].mean().reset_index()
    df_monthly['Date'] = df_monthly['YearMonth'].dt.to_timestamp()
    df_monthly = df_monthly[['Date', 'Price']].reset_index(drop=True)
    
    print(f"\nMonthly data shape: {df_monthly.shape}")
    print(f"Monthly data range: {df_monthly['Date'].min().date()} to {df_monthly['Date'].max().date()}")
    print("\nMonthly Price Summary:")
    print(df_monthly['Price'].describe())

Daily data shape: (2515, 2)
Date range: 2016-01-11 00:00:00 to 2026-01-09 00:00:00

Monthly data shape: (121, 2)
Monthly data range: 2016-01-01 to 2026-01-01

Monthly Price Summary:
count    121.000000
mean     173.130152
std       63.845037
min      105.368572
25%      122.885238
50%      164.941904
75%      182.140434
max      409.205836
Name: Price, dtype: float64


## 2. Seasonal Decomposition

In [3]:
# Seasonal decomposition (need at least 24 months)
if len(df_monthly) >= 24:
    decomposition = seasonal_decompose(df_monthly['Price'], model='additive', period=12)
    
    trend = decomposition.trend
    seasonal = decomposition.seasonal
    residual = decomposition.resid
    
    print("Decomposition Analysis:")
    print(f"\nTrend Component:")
    print(f"  Mean: {trend.dropna().mean():.2f}")
    print(f"  Std Dev: {trend.dropna().std():.2f}")
    
    print(f"\nSeasonal Component:")
    print(f"  Mean: {seasonal.dropna().mean():.4f}")
    print(f"  Std Dev: {seasonal.dropna().std():.2f}")
    print(f"  Range: {seasonal.dropna().min():.2f} to {seasonal.dropna().max():.2f}")
    
    print(f"\nResidual Component:")
    print(f"  Mean: {residual.dropna().mean():.4f}")
    print(f"  Std Dev: {residual.dropna().std():.2f}")
    
    # Seasonal strength
    seasonal_var = seasonal.dropna().var()
    residual_var = residual.dropna().var()
    seasonal_strength = 1 - (residual_var / (seasonal_var + residual_var))
    
    print(f"\nSeasonal Strength: {seasonal_strength:.4f}")
    print(f"  Interpretation: {seasonal_strength*100:.1f}% of variation is seasonal")
    if seasonal_strength > 0.1:
        print(f"  ✓ Strong enough for SARIMA")
    else:
        print(f"  ⚠ Weak seasonality, ARIMA may suffice")

Decomposition Analysis:

Trend Component:
  Mean: 166.12
  Std Dev: 46.74

Seasonal Component:
  Mean: -0.0102
  Std Dev: 2.21
  Range: -4.02 to 4.53

Residual Component:
  Mean: -0.5574
  Std Dev: 5.33

Seasonal Strength: 0.1468
  Interpretation: 14.7% of variation is seasonal
  ✓ Strong enough for SARIMA


In [4]:
# Visualize decomposition
fig = sp.make_subplots(
    rows=4, cols=1,
    subplot_titles=('Original', 'Trend', 'Seasonal', 'Residual'),
    vertical_spacing=0.08
)

fig.add_trace(
    go.Scatter(x=df_monthly['Date'], y=df_monthly['Price'], name='Original', 
              line=dict(color='#FFD700', width=2)),
    row=1, col=1
)

fig.add_trace(
    go.Scatter(x=df_monthly['Date'], y=trend, name='Trend',
              line=dict(color='#FF6B6B', width=2)),
    row=2, col=1
)

fig.add_trace(
    go.Scatter(x=df_monthly['Date'], y=seasonal, name='Seasonal',
              line=dict(color='#4ECDC4', width=2)),
    row=3, col=1
)

fig.add_trace(
    go.Scatter(x=df_monthly['Date'], y=residual, name='Residual',
              line=dict(color='#95E1D3', width=1)),
    row=4, col=1
)

fig.update_yaxes(title_text="Price ($)", row=1, col=1)
fig.update_yaxes(title_text="Trend", row=2, col=1)
fig.update_yaxes(title_text="Seasonal", row=3, col=1)
fig.update_yaxes(title_text="Residual", row=4, col=1)
fig.update_layout(height=900, hovermode='x unified', showlegend=False)
fig.show()

## 3. Stationarity Testing

In [5]:
# ADF test helper
def adf_test(series, name):
    result = adfuller(series, autolag='AIC')
    p_val = result[1]
    is_stat = p_val <= 0.05
    
    print(f"\n{name}:")
    print(f"  ADF Statistic: {result[0]:.6f}")
    print(f"  P-value: {p_val:.6f}")
    print(f"  Status: {'✓ STATIONARY' if is_stat else '✗ NON-STATIONARY'}")
    
    return is_stat, p_val

price_monthly = df_monthly['Price'].values

# Test original
stat_orig, p_orig = adf_test(price_monthly, "Original Series")

# First difference
diff_d1 = np.diff(price_monthly, n=1)
stat_d1, p_d1 = adf_test(diff_d1, "First Difference (d=1)")

# Seasonal difference
diff_seasonal = np.diff(price_monthly, n=12)
stat_seasonal, p_seasonal = adf_test(diff_seasonal, "Seasonal Difference (D=1, 12m)")

# Combined
diff_both = np.diff(diff_d1, n=12)
stat_both, p_both = adf_test(diff_both, "Combined d=1, D=1")

print(f"\n{'='*50}")
print(f"Recommendation: d=1, D={'1' if p_seasonal <= 0.05 else '0'}")


Original Series:
  ADF Statistic: 4.389512
  P-value: 1.000000
  Status: ✗ NON-STATIONARY

First Difference (d=1):
  ADF Statistic: -7.049122
  P-value: 0.000000
  Status: ✓ STATIONARY

Seasonal Difference (D=1, 12m):
  ADF Statistic: -11.117740
  P-value: 0.000000
  Status: ✓ STATIONARY

Combined d=1, D=1:
  ADF Statistic: -9.775437
  P-value: 0.000000
  Status: ✓ STATIONARY

Recommendation: d=1, D=1


## 4. Train-Test Split

In [6]:
train_size = int(len(df_monthly) * 0.8)
train_data = df_monthly[:train_size].copy()
test_data = df_monthly[train_size:].copy()

print(f"Training set: {len(train_data)} months ({len(train_data)/len(df_monthly)*100:.1f}%)")
print(f"Test set: {len(test_data)} months ({len(test_data)/len(df_monthly)*100:.1f}%)")
print(f"\nTrain period: {train_data['Date'].min().date()} to {train_data['Date'].max().date()}")
print(f"Test period: {test_data['Date'].min().date()} to {test_data['Date'].max().date()}")

# Visualize split
fig = go.Figure()
fig.add_trace(go.Scatter(
    x=train_data['Date'], y=train_data['Price'],
    mode='lines', name='Training Data',
    line=dict(color='#FFD700', width=2)
))
fig.add_trace(go.Scatter(
    x=test_data['Date'], y=test_data['Price'],
    mode='lines', name='Test Data',
    line=dict(color='#FF6B6B', width=2)
))
fig.add_vline(x=train_data['Date'].max(), line_dash="dash", line_color="gray")
fig.update_layout(
    title="Train-Test Split (80/20)",
    xaxis_title="Date",
    yaxis_title="Price ($)",
    hovermode='x unified',
    height=400
)
fig.show()

Training set: 96 months (79.3%)
Test set: 25 months (20.7%)

Train period: 2016-01-01 to 2023-12-01
Test period: 2024-01-01 to 2026-01-01


## 5. Fit SARIMA Models

In [7]:
# Define SARIMA configurations
sarima_configs = [
    ((1, 1, 0), (0, 0, 0, 12)),  # Simple AR with trend
    ((0, 1, 1), (1, 0, 0, 12)),  # MA with seasonal AR
    ((0, 1, 1), (0, 1, 0, 12)),  # MA with seasonal differencing
]

print(f"Fitting {len(sarima_configs)} SARIMA models on {len(train_data)} months...\n")
print(f"{'Model':<30} {'AIC':<12} {'BIC':<12} {'RMSE':<10}")
print("-" * 64)

results = []

for order, seasonal_order in sarima_configs:
    try:
        model = SARIMAX(
            train_data['Price'], 
            order=order, 
            seasonal_order=seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False
        )
        fitted_model = model.fit(disp=False, maxiter=50)
        
        # Forecast
        forecast = fitted_model.get_forecast(steps=len(test_data))
        forecast_values = forecast.predicted_mean.values
        
        # Metrics
        aic = fitted_model.aic
        bic = fitted_model.bic
        rmse = np.sqrt(mean_squared_error(test_data['Price'], forecast_values))
        
        results.append({
            'Model': f"ARIMA{order}{seasonal_order}",
            'AIC': aic,
            'BIC': bic,
            'RMSE': rmse,
            'Forecast': forecast_values,
            'FittedModel': fitted_model
        })
        
        print(f"{str(order)+str(seasonal_order):<30} {aic:<12.2f} {bic:<12.2f} {rmse:<10.2f}")
    except Exception as e:
        print(f"{str(order)+str(seasonal_order):<30} Failed: {str(e)[:35]}")

results_df = pd.DataFrame(results)
if len(results) > 0:
    optimal_idx = results_df['BIC'].idxmin()
    print(f"\n✓ Best Model (BIC): {results_df.loc[optimal_idx, 'Model']}")
    print(f"  BIC: {results_df.loc[optimal_idx, 'BIC']:.2f}")

Fitting 3 SARIMA models on 96 months...

Model                          AIC          BIC          RMSE      
----------------------------------------------------------------
(1, 1, 0)(0, 0, 0, 12)         547.68       552.77       105.97    
(0, 1, 1)(1, 0, 0, 12)         487.80       495.05       105.05    
(0, 1, 1)(0, 1, 0, 12)         523.08       527.87       81.40     

✓ Best Model (BIC): ARIMA(0, 1, 1)(1, 0, 0, 12)
  BIC: 495.05


## 6. Optimal Model Analysis

In [8]:
optimal_model = results_df.loc[optimal_idx, 'FittedModel']
optimal_forecast = results_df.loc[optimal_idx, 'Forecast']

# Performance metrics
rmse = np.sqrt(mean_squared_error(test_data['Price'], optimal_forecast))
mae = mean_absolute_error(test_data['Price'], optimal_forecast)
mape = np.mean(np.abs((test_data['Price'] - optimal_forecast) / test_data['Price'])) * 100

naive_forecast = np.full(len(test_data), train_data['Price'].iloc[-1])
naive_rmse = np.sqrt(mean_squared_error(test_data['Price'], naive_forecast))
improvement = (naive_rmse - rmse) / naive_rmse * 100

print("Test Set Performance:")
print(f"\nOptimal Model Metrics:")
print(f"  RMSE: {rmse:.2f}")
print(f"  MAE: {mae:.2f}")
print(f"  MAPE: {mape:.2f}%")

print(f"\nNaive Baseline RMSE: {naive_rmse:.2f}")
print(f"Improvement: {improvement:+.2f}%")

print(f"\nModel Summary:")
print(optimal_model.summary())

Test Set Performance:

Optimal Model Metrics:
  RMSE: 105.05
  MAE: 83.99
  MAPE: 27.06%

Naive Baseline RMSE: 107.09
Improvement: +1.91%

Model Summary:
                                      SARIMAX Results                                      
Dep. Variable:                               Price   No. Observations:                   96
Model:             SARIMAX(0, 1, 1)x(1, 0, [], 12)   Log Likelihood                -240.899
Date:                             Mon, 26 Jan 2026   AIC                            487.798
Time:                                     00:16:28   BIC                            495.054
Sample:                                          0   HQIC                           490.713
                                              - 96                                         
Covariance Type:                               opg                                         
                 coef    std err          z      P>|z|      [0.025      0.975]
-------------------------------

## 7. Forecast Visualization

In [9]:
# Get confidence intervals
forecast_obj = optimal_model.get_forecast(steps=len(test_data))
forecast_ci = forecast_obj.conf_int(alpha=0.05)
forecast_mean = forecast_obj.predicted_mean

fig = go.Figure()

# Training data
fig.add_trace(go.Scatter(
    x=train_data['Date'], y=train_data['Price'],
    mode='lines', name='Training Data',
    line=dict(color='#FFD700', width=2)
))

# Test data
fig.add_trace(go.Scatter(
    x=test_data['Date'], y=test_data['Price'],
    mode='lines', name='Actual Test Data',
    line=dict(color='#FF6B6B', width=2)
))

# SARIMA forecast
fig.add_trace(go.Scatter(
    x=test_data['Date'], y=forecast_mean,
    mode='lines', name='SARIMA Forecast',
    line=dict(color='#4ECDC4', width=2, dash='dash')
))

# Confidence interval
fig.add_trace(go.Scatter(
    x=test_data['Date'],
    y=forecast_ci.iloc[:, 0],
    fill=None, mode='lines',
    line_color='rgba(0,0,0,0)',
    name='95% CI Lower'
))

fig.add_trace(go.Scatter(
    x=test_data['Date'],
    y=forecast_ci.iloc[:, 1],
    fill='tonexty', mode='lines',
    line_color='rgba(0,0,0,0)',
    name='95% CI Upper',
    fillcolor='rgba(68, 205, 196, 0.2)'
))

fig.update_layout(
    title=f"SARIMA Gold Price Forecast (RMSE: {rmse:.2f}, Improvement: {improvement:+.2f}%)",
    xaxis_title="Date",
    yaxis_title="Price ($)",
    hovermode='x unified',
    height=500
)
fig.show()

## 8. Residual Diagnostics

In [10]:
residuals = optimal_model.resid
residuals_test = test_data['Price'].values - optimal_forecast

print("Residual Analysis:")
print(f"\nIn-Sample Residuals:")
print(f"  Mean: {residuals.mean():.4f}")
print(f"  Std Dev: {residuals.std():.4f}")
print(f"  Min: {residuals.min():.2f}, Max: {residuals.max():.2f}")

print(f"\nTest Set Residuals:")
print(f"  Mean: {residuals_test.mean():.2f}")
print(f"  Std Dev: {residuals_test.std():.2f}")
print(f"  Min: {residuals_test.min():.2f}, Max: {residuals_test.max():.2f}")

Residual Analysis:

In-Sample Residuals:
  Mean: 1.7282
  Std Dev: 11.5559
  Min: -9.76, Max: 105.37

Test Set Residuals:
  Mean: 83.68
  Std Dev: 63.50
  Min: -1.98, Max: 217.90


In [11]:
# Residual plots
fig = sp.make_subplots(
    rows=1, cols=2,
    subplot_titles=('Residuals Over Time', 'Residual Distribution')
)

fig.add_trace(
    go.Scatter(x=train_data['Date'], y=residuals, mode='markers',
              name='Residuals', marker=dict(color='#FF6B6B')),
    row=1, col=1
)
fig.add_hline(y=0, line_dash="dash", line_color="black", row=1, col=1)

fig.add_trace(
    go.Histogram(x=residuals, nbinsx=30, name='Distribution',
                marker=dict(color='#FFD700')),
    row=1, col=2
)

fig.update_yaxes(title_text="Residual", row=1, col=1)
fig.update_yaxes(title_text="Frequency", row=1, col=2)
fig.update_layout(height=400, showlegend=False)
fig.show()

## 9. Key Insights

In [None]:
print("="*60)
print("KEY INSIGHTS: SARIMA VS ARIMA")
print("="*60)

print("\nSARIMA Advantages:")
print("  ✓ Captures seasonal patterns explicitly")
print("  ✓ Separates seasonal from trend components")
print("  ✓ Better for seasonal data (retail, weather, tourism)")

print("\nSARIMA Components (p,d,q)(P,D,Q)s:")
print("  (p,d,q) = Non-seasonal (trend & differencing)")
print("  (P,D,Q) = Seasonal (yearly/monthly cycles)")
print("  s = Seasonal period (12 for monthly data)")

print("\nWhen to Use SARIMA:")
print("  ✓ Retail sales (holiday seasonality)")
print("  ✓ Weather data (seasonal cycles)")
print("  ✓ Tourism (summer/winter patterns)")
print("  ✓ Utilities (seasonal usage)")
print("  ✗ Non-seasonal data (use ARIMA)")

print("\nARIMA Family Progression:")
print("  Day 12: AR(p) - Autoregressive only")
print("  Day 13: ARMA(p,q) - + Moving Average")
print("  Day 14: ARIMA(p,d,q) - + Differencing")
print("  Day 15: SARIMA(p,d,q)(P,D,Q)s - + Seasonality")

KEY INSIGHTS: SARIMA VS ARIMA

SARIMA Advantages:
  ✓ Captures seasonal patterns explicitly
  ✓ Separates seasonal from trend components
  ✓ Better for seasonal data (retail, weather, tourism)

SARIMA Components (p,d,q)(P,D,Q)s:
  (p,d,q) = Non-seasonal (trend & differencing)
  (P,D,Q) = Seasonal (yearly/monthly cycles)
  s = Seasonal period (12 for monthly data)

When to Use SARIMA:
  ✓ Retail sales (holiday seasonality)
  ✓ Weather data (seasonal cycles)
  ✓ Tourism (summer/winter patterns)
  ✓ Utilities (seasonal usage)
  ✗ Non-seasonal data (use ARIMA)

ARIMA Family Progression:
  Day 12: AR(p) - Autoregressive only
  Day 13: ARMA(p,q) - + Moving Average
  Day 14: ARIMA(p,d,q) - + Differencing
  Day 15: SARIMA(p,d,q)(P,D,Q)s - + Seasonality ← HERE
