# Day 18: SARIMA for Seasonality
## Advanced Seasonal ARIMA Modeling

This notebook tests whether daily gold prices show an annual seasonal pattern (a period of 252 trading days). It then compares seasonal SARIMA models against a non-seasonal ARIMA baseline for forecasting.

## Section 1: Import and Setup

In [None]:
# Core data handling and plotting.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): plotly (go, sp) and scipy.stats are not used by any cell in this notebook;
# presumably left over from exploration — consider removing.
import plotly.graph_objects as go
import plotly.subplots as sp
from scipy import stats

# Time-series modelling, decomposition, unit-root tests and error metrics.
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.seasonal import seasonal_decompose
# NOTE(review): pacf, plot_acf and plot_pacf are imported but not used by any cell below.
from statsmodels.tsa.stattools import adfuller, acf, pacf
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from sklearn.metrics import mean_squared_error, mean_absolute_error

import warnings
# NOTE(review): this silences *all* warnings, including statsmodels ConvergenceWarning from the
# SARIMAX fits in Section 6, so a non-converged fit is indistinguishable from a good one.
# Consider narrowing the filter to specific categories/messages.
warnings.filterwarnings('ignore')

print("✓ Imports successful")

## Section 2: Load and Prepare Data

In [None]:
# Load data: expects a 'Date' column plus either 'Price' or 'Adj Close' (renamed to 'Price').
df = pd.read_csv('../data/gold_prices.csv', parse_dates=['Date'])
if 'Price' not in df.columns:
    if 'Adj Close' not in df.columns:
        raise KeyError("gold_prices.csv needs a 'Price' or 'Adj Close' column")
    df = df.rename(columns={'Adj Close': 'Price'})

# Drop rows with a missing date/price *before* de-duplicating, so a NaN row can never shadow a
# valid row for the same date; then order chronologically. seasonal_decompose and adfuller
# (Sections 3 and 5) both reject NaN values.
df = (
    df.dropna(subset=['Date', 'Price'])
      .drop_duplicates(subset=['Date'])
      .sort_values('Date')
      .reset_index(drop=True)
)

# Chronological train-test split (no shuffling for time series): first 80% train, rest test.
TRAIN_FRACTION = 0.8
train_size = int(len(df) * TRAIN_FRACTION)
train_data = df.iloc[:train_size].copy()
test_data = df.iloc[train_size:].copy()

print(f"Data loaded: {len(df)} observations")
print(f"Train: {len(train_data)} | Test: {len(test_data)}")
print(f"Date range: {df['Date'].min().date()} to {df['Date'].max().date()}")

## Section 3: Seasonal Decomposition

In [None]:
# Additive decomposition of the training prices with an annual period of 252 trading days.
# extrapolate_trend='freq' linearly extrapolates the centred moving-average trend into the
# first/last half period, so trend and residual contain no NaN edge values.
decomposition = seasonal_decompose(train_data['Price'], model='additive', period=252,
                                   extrapolate_trend='freq')

trend = decomposition.trend
seasonal = decomposition.seasonal
residual = decomposition.resid

print("Decomposition Results:")
print(f"Trend range: [{trend.min():.2f}, {trend.max():.2f}]")
print(f"Seasonal range: [{seasonal.min():.2f}, {seasonal.max():.2f}]")
print(f"Residual std: {residual.std():.2f}")

# Seasonal strength as the share of (seasonal + residual) variance carried by the seasonal
# component: Var(S) / (Var(S) + Var(R)).
# Note: this differs from Hyndman's F_S = max(0, 1 - Var(R) / Var(S + R)) when S and R are correlated.
seasonal_var = np.var(seasonal)
residual_var = np.var(residual)
seasonal_strength = seasonal_var / (seasonal_var + residual_var)

print(f"\nSeasonal Strength: {seasonal_strength:.4f} ({seasonal_strength*100:.2f}%)")
if seasonal_strength < 0.1:
    print("  → Weak seasonality")
elif seasonal_strength < 0.3:
    print("  → Moderate seasonality")
else:
    print("  → Strong seasonality")

## Section 4: Seasonal Pattern Analysis

In [None]:
# Sample autocorrelation of training prices at whole multiples of the 252-trading-day year.
price_series = train_data['Price']
seasonal_lags = [252, 504, 756, 1008]
# nlags must reach the largest lag reported below, but acf cannot go beyond len(series) - 1;
# lags the series is too short for are skipped by the `lag < len(acf_vals)` check.
acf_vals = acf(price_series, nlags=min(max(seasonal_lags), len(price_series) - 1), fft=True)

print("Autocorrelation at seasonal lags (252-day period):")
for lag in seasonal_lags:
    if lag < len(acf_vals):
        acf_val = acf_vals[lag]
        print(f"  Lag {lag:4d} (year {lag//252}): ACF = {acf_val:.4f}")

# Seasonal peak and trough of the signed seasonal component (not its absolute value, which could
# report a deep trough as the "peak"). The component repeats every period, so argmax/argmin
# return the occurrence in the first cycle. `seasonal` shares train_data's positional index.
seasonal_peak_idx = int(np.argmax(seasonal.values))
seasonal_peak_date = train_data['Date'].iloc[seasonal_peak_idx]
seasonal_peak_value = seasonal.iloc[seasonal_peak_idx]

seasonal_trough_idx = int(np.argmin(seasonal.values))
seasonal_trough_date = train_data['Date'].iloc[seasonal_trough_idx]
seasonal_trough_value = seasonal.iloc[seasonal_trough_idx]

print(f"\nSeasonal Peak:")
print(f"  Date: {seasonal_peak_date.date()}")
print(f"  Value: {seasonal_peak_value:.2f}")
print(f"  Month: {seasonal_peak_date.month_name()}")

print(f"\nSeasonal Trough:")
print(f"  Date: {seasonal_trough_date.date()}")
print(f"  Value: {seasonal_trough_value:.2f}")
print(f"  Month: {seasonal_trough_date.month_name()}")

## Section 5: Stationarity Testing

In [None]:
def _adf_report(label, series):
    """Run an Augmented Dickey-Fuller test on `series` and print the p-value with a 5% verdict.

    NaNs are dropped first because adfuller does not accept them.
    Returns (adf_statistic, p_value).
    """
    adf_stat, p_value = adfuller(series.dropna())[:2]
    verdict = "(NON-STATIONARY)" if p_value > 0.05 else "(STATIONARY)"
    print(f"{label}: ADF p = {p_value:.6f}", verdict)
    return adf_stat, p_value


# Original series
adf_stat, adf_p = _adf_report("Original series", train_data['Price'])

# First difference
diff1 = train_data['Price'].diff().dropna()
adf_stat_d1, adf_p_d1 = _adf_report("First difference (d=1)", diff1)

# Seasonal difference
diff_s = train_data['Price'].diff(252).dropna()
adf_stat_ds, adf_p_ds = _adf_report("Seasonal difference (D=1, s=252)", diff_s)

# Combined: seasonal difference at lag 252 (not 251) applied on top of the first difference.
# diff1 is contiguous after dropping its leading NaN, so a 252-row shift is a 252-day lag.
diff_combined = diff1.diff(252).dropna()
adf_stat_combined, adf_p_combined = _adf_report("Combined (d=1, D=1, s=252)", diff_combined)

## Section 6: Fit Multiple SARIMA Models

In [None]:
# Candidate specifications: a non-seasonal (0,1,0) baseline plus four seasonal variants with a
# seasonal period of 252 trading days.
models = [
    {'name': 'ARIMA(0,1,0)', 'order': (0,1,0), 'seasonal_order': (0,0,0,252), 'seasonal': 'No'},
    {'name': 'SARIMA(0,1,0)(0,1,0,252)', 'order': (0,1,0), 'seasonal_order': (0,1,0,252), 'seasonal': 'Yes'},
    {'name': 'SARIMA(1,1,0)(1,1,0,252)', 'order': (1,1,0), 'seasonal_order': (1,1,0,252), 'seasonal': 'Yes'},
    {'name': 'SARIMA(0,1,1)(0,1,1,252)', 'order': (0,1,1), 'seasonal_order': (0,1,1,252), 'seasonal': 'Yes'},
    {'name': 'SARIMA(1,1,1)(1,1,1,252)', 'order': (1,1,1), 'seasonal_order': (1,1,1,252), 'seasonal': 'Yes'},
]

results = []

print("Fitting SARIMA models...\n")
for position, spec in enumerate(models, start=1):
    name = spec['name']
    try:
        fit_result = SARIMAX(
            train_data['Price'],
            order=spec['order'],
            seasonal_order=spec['seasonal_order'],
            enforce_stationarity=False,
            enforce_invertibility=False,
        ).fit(disp=False)

        # Multi-step out-of-sample forecast covering the whole test horizon.
        predicted = fit_result.get_forecast(steps=len(test_data)).predicted_mean.values
        actual = test_data['Price']
        test_rmse = np.sqrt(mean_squared_error(actual, predicted))
        test_mae = mean_absolute_error(actual, predicted)

        results.append({
            'Model': name,
            'Seasonal': spec['seasonal'],
            'AIC': fit_result.aic,
            'BIC': fit_result.bic,
            'Test RMSE': test_rmse,
            'Test MAE': test_mae,
            'Status': '✓',
        })
        print(f"{position}. {name:30s} | AIC: {fit_result.aic:8.2f} | RMSE: {test_rmse:7.2f} | ✓")
    except Exception as err:
        # Keep a row for every candidate: failed fits get NaN metrics and a ✗ status.
        print(f"{position}. {name:30s} | ERROR: {str(err)[:30]}")
        failed_row = {'Model': name, 'Seasonal': spec['seasonal']}
        failed_row.update(dict.fromkeys(['AIC', 'BIC', 'Test RMSE', 'Test MAE'], np.nan))
        failed_row['Status'] = '✗'
        results.append(failed_row)

results_df = pd.DataFrame(results)
print("\nModel Comparison:")
print(results_df.to_string(index=False))

## Section 7: Seasonal vs Non-Seasonal Comparison

In [None]:
# Compare mean out-of-sample RMSE of seasonal vs. non-seasonal specifications. Only successful
# fits count: failed fits carry NaN metrics and would otherwise make an average NaN.
successful = results_df[(results_df['Status'] == '✓') & results_df['Test RMSE'].notna()]
seasonal_models = successful[successful['Seasonal'] == 'Yes']
non_seasonal_models = successful[successful['Seasonal'] == 'No']

if len(seasonal_models) > 0 and len(non_seasonal_models) > 0:
    avg_seasonal_rmse = seasonal_models['Test RMSE'].mean()
    avg_non_seasonal_rmse = non_seasonal_models['Test RMSE'].mean()

    print("Average Test RMSE:")
    print(f"  Non-seasonal models: {avg_non_seasonal_rmse:.2f}")
    print(f"  Seasonal models: {avg_seasonal_rmse:.2f}")

    # Positive = seasonal models have lower average RMSE than the non-seasonal baseline.
    improvement = (avg_non_seasonal_rmse - avg_seasonal_rmse) / avg_non_seasonal_rmse * 100
    if improvement > 0:
        print(f"  Improvement: +{improvement:.2f}%")
    else:
        print(f"  Degradation: {improvement:.2f}%")
else:
    print("Comparison skipped: need at least one successfully fitted seasonal "
          "and one non-seasonal model.")

## Section 8: Best Model Detailed Analysis

In [None]:
# Pick the specification with the lowest out-of-sample RMSE among successful fits.
valid_results = results_df[(results_df['Status'] == '✓') & results_df['Test RMSE'].notna()]
if valid_results.empty:
    raise RuntimeError("No SARIMA model fitted successfully; cannot select a best model.")
# results_df rows were appended in the same order as `models`, so the row label is also the
# position of the winning spec in `models`.
best_idx = valid_results['Test RMSE'].idxmin()
best_model_spec = models[best_idx]

print(f"Best Model: {best_model_spec['name']}")
print(f"\nFitting best model with detailed summary...\n")

# Refit the winner (Section 6 does not keep fitted objects) to inspect its summary and residuals.
best_model = SARIMAX(train_data['Price'],
                    order=best_model_spec['order'],
                    seasonal_order=best_model_spec['seasonal_order'],
                    enforce_stationarity=False,
                    enforce_invertibility=False)
best_fitted = best_model.fit(disp=False)

print(best_fitted.summary())

# Residual analysis. Differencing is handled inside the state-space model, so the first
# residuals come from the diffuse initialisation (for d=1 the first one is typically close to the
# price level itself). Skip the same burn-in statsmodels excludes from the likelihood and
# diagnostics, otherwise mean/std/max are dominated by those few values.
burn_in = int(max(getattr(best_fitted, 'loglikelihood_burn', 0),
                  getattr(best_fitted, 'nobs_diffuse', 0)))
residuals = best_fitted.resid.iloc[burn_in:]
print(f"\nResidual Statistics (excluding first {burn_in} burn-in residuals):")
print(f"  Mean: {residuals.mean():.6f}")
print(f"  Std Dev: {residuals.std():.6f}")
print(f"  Min: {residuals.min():.2f}")
print(f"  Max: {residuals.max():.2f}")

## Section 9: Visualization

In [None]:
# Four stacked panels: the observed training prices followed by the trend, seasonal and
# residual components from Section 3, all against the training dates.
panels = [
    (train_data['Price'], 'Original', 'Price'),
    (trend, 'Trend', 'Trend'),
    (seasonal, 'Seasonal', 'Seasonal'),
    (residual, 'Residual', 'Residual'),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(14, 10))
for ax, (values, label, y_label) in zip(axes, panels):
    ax.plot(train_data['Date'], values, label=label)
    ax.set_ylabel(y_label)
    ax.legend()

axes[0].set_title('Gold Price and Seasonal Decomposition (Period=252)')
axes[-1].set_xlabel('Date')

plt.tight_layout()
plt.show()

print("✓ Decomposition plot created")

## Section 10: Summary and Recommendations

In [None]:
print("="*70)
print("SUMMARY AND RECOMMENDATIONS")
print("="*70)

print("\nKey Findings:")
print(f"  Seasonal Strength: {seasonal_strength*100:.2f}%")
print("  Seasonal Period: 252 trading days (annual)")
print(f"  Best Model by RMSE: {best_model_spec['name']}")
# NOTE: this RMSE was also the selection criterion, so it is optimistically biased.
print(f"  Test RMSE: {results_df.loc[best_idx, 'Test RMSE']:.2f}")

print("\nRecommendations:")
if seasonal_strength < 0.1:
    print("  → Weak seasonality (<10%): Non-seasonal models recommended")
elif seasonal_strength < 0.3:
    print("  → Moderate seasonality: Consider seasonal models if improvement > 5%")
else:
    print("  → Strong seasonality: Seasonal models recommended")

print("\nNext Steps:")
# The test split was already used to choose the best model, so a final check needs data the
# selection never saw (a later holdout period or walk-forward validation).
print("  1. Validate the selected model on a fresh holdout period (the test set was used for model selection)")
print("  2. Compare multiple forecasting horizons (1-day to 30-day ahead)")
print("  3. Consider ensemble methods combining seasonal and non-seasonal")
print("  4. Monitor seasonal patterns over time (they may change)")
print("  5. Explore multiple seasonal periods if applicable")