In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import detrend, find_peaks, stft, periodogram
from scipy.fftpack import fft, fftfreq
from statsmodels.tsa.filters.hp_filter import hpfilter
from statsmodels.tsa.stattools import adfuller

from arch import arch_model
from statsmodels.tsa.statespace.sarimax import SARIMAX
import pmdarima as pm

In [2]:
def advanced_detrend(df, column="Close", lamb=14400):
    """
    Split one column into trend and cycle components with the Hodrick-Prescott filter.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the series. It is modified in place: a "Trend" and a
        "Cycle" column are added (pass ``df.copy()`` to keep the input intact).
    column : str, default "Close"
        Name of the column to filter.
    lamb : float, default 14400
        HP smoothing parameter forwarded to ``hpfilter``.

    Returns
    -------
    tuple
        ``(df, trend)`` -- the same (now augmented) frame and the trend series.
    """
    components = hpfilter(df[column], lamb=lamb)  # (cycle, trend)
    hp_cycle, hp_trend = components[0], components[1]
    df["Trend"] = hp_trend
    df["Cycle"] = hp_cycle
    return df, hp_trend

def preprocess_data(df, column="Open"):
    """
    Remove the least-squares linear trend from a single column.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    column : str, default "Open"
        Name of the column to detrend.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` in which ``column`` holds the residual left after
        ``scipy.signal.detrend`` subtracts a straight-line fit.
    """
    result = df.copy()
    raw_values = result[column]
    result[column] = detrend(raw_values)
    return result

def check_stationarity(series, significance=0.05):
    """
    Run an Augmented Dickey-Fuller (ADF) test and report whether the series looks stationary.

    NaNs are dropped before testing. The ADF statistic and p-value are printed,
    followed by a one-line verdict.

    Parameters
    ----------
    series : pandas.Series
        Series to test.
    significance : float, default 0.05
        The unit-root null is rejected when the p-value falls below this level.

    Returns
    -------
    bool
        True if the p-value is below ``significance`` (stationary), else False.
    """
    clean = series.dropna()
    adf_stat, p_value = adfuller(clean)[:2]
    print("ADF Statistic: {:.4f}".format(adf_stat))
    print("p-value: {:.4f}".format(p_value))
    stationary = bool(p_value < significance)
    verdict = "The series is stationary." if stationary else "The series is non-stationary."
    print(verdict)
    return stationary

def fourier_analysis(df, dataset, column="Trend", top_n=5, plot=True):
    """
    Find the dominant periodic components of one column with an FFT.

    Samples are assumed to be evenly spaced, one per row (``T = 1``), so the
    reported "days" are really "rows". NOTE(review): if the data are trading
    days, a period of 5 rows is one calendar week -- confirm before reading
    the periods as calendar days.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing ``column``.
    dataset : str
        Label used in the plot title.
    column : str, default "Trend"
        Column to analyse.
    top_n : int, default 5
        Number of strongest spectral peaks to report.
    plot : bool, default True
        Draw the one-sided amplitude spectrum with matplotlib.

    Returns
    -------
    tuple of numpy.ndarray
        ``(dominant_periods, dominant_amplitudes)`` for up to ``top_n`` peaks,
        strongest first. Both arrays are empty when no peak is found or the
        column is empty.
    """
    y = np.asarray(df[column].values, dtype=float)
    N = len(y)
    if N == 0:
        print("No data available for Fourier analysis.")
        return np.array([]), np.array([])

    T = 1
    yf = fft(y)
    xf = fftfreq(N, T)[:N // 2]
    amplitudes = 2.0 / N * np.abs(yf[:N // 2])

    # The zero-frequency (DC) bin equals 2 * mean(y). For a price series it
    # dwarfs every periodic component, so a threshold based on it would scale
    # with the price level and suppress real seasonal peaks. Base the 1% cut
    # on the non-DC bins only. (find_peaks never returns an edge index, so
    # bin 0 -- and the 1/0 it would cause -- is never selected.)
    non_dc = amplitudes[1:]
    non_dc_max = non_dc.max() if non_dc.size else 0.0
    if non_dc_max > 0:
        peaks, _ = find_peaks(amplitudes, height=0.01 * non_dc_max)
    else:
        peaks = np.array([], dtype=int)
    print("Detected peaks:", peaks[:5])

    order = np.argsort(amplitudes[peaks])[::-1]  # strongest first
    strongest_peaks = peaks[order][:top_n]

    dominant_periods = 1 / xf[strongest_peaks]
    dominant_amplitudes = amplitudes[strongest_peaks]
    print("Strongest detected seasonal periods (in days) and their amplitudes:")
    for period, amplitude in zip(dominant_periods, dominant_amplitudes):
        print(f"Period: {period:.2f} days, Amplitude: {amplitude:.5f}")

    if plot:
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(xf, amplitudes)
        ax.set_xlabel("Frequency (1/day)")
        ax.set_ylabel("Amplitude")
        ax.set_title(f"Fourier Transform - Seasonality Detection ({dataset})")
        ax.grid()
        plt.show()

    return dominant_periods, dominant_amplitudes
def plot_hp_filter_results(df, dataset_name, column="Close", lamb=14400):
    """
    Plot a series together with its Hodrick-Prescott trend and cycle.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``column`` plus the "Trend" and "Cycle" columns added by
        ``advanced_detrend``. Its index is used as the x-axis.
    dataset_name : str
        Label used in the title.
    column : str, default "Close"
        Column holding the original series; also used in its legend label.
    lamb : float, default 14400
        Smoothing parameter that produced "Trend"/"Cycle". It is only used in
        labels; this function does not re-run the filter.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    # The label follows `column` (it used to say "Close" even when another column was plotted).
    ax.plot(df.index, df[column], label=f"Original {column} Price", linewidth=1)
    ax.plot(df.index, df["Trend"], label=f"HP Trend (lambda={lamb})", linewidth=2, color='red')
    ax.plot(df.index, df["Cycle"], label=f"HP Cycle (lambda={lamb})", linewidth=1, linestyle='--', color='green', alpha=0.7)
    ax.set_title(f"Hodrick-Prescott Filter Decomposition for {dataset_name} (Lambda = {lamb})")
    ax.set_xlabel("Date")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()


datasets = ["../data/CORN-Prices.csv", "../data/GLD-Prices.csv", "../data/SLV-Prices.csv","../data/WEAT-Prices.csv"]
lambda_values = [100, 1600, 14400, 129600] # HP smoothing values compared in the sweep below

# Set to True to run the HP-filter lambda sweep (followed by a Fourier analysis
# of the trend from the last lambda in `lambda_values`) for every dataset.
# With False the loop only loads and sorts each file.
RUN_LAMBDA_SWEEP = False
# Set to True to also draw the trend/cycle decomposition for each lambda.
PLOT_HP_RESULTS = False

for dataset in datasets:
    commodity_name = dataset.split("/")[-1].split(".")[0]
    if not os.path.exists(dataset):
        print(f"Error: Dataset file not found: {dataset}")
        continue

    df = (
        pd.read_csv(dataset, parse_dates=["Date"])
        .sort_values("Date")
        .set_index("Date")
    )

    if not RUN_LAMBDA_SWEEP:
        continue

    print(f"\n--- Processing dataset: {commodity_name} ---")
    for lamb in lambda_values:
        print(f"\n--- Lambda: {lamb} ---")
        # advanced_detrend returns (frame, trend); only the frame is needed here.
        df_processed, _ = advanced_detrend(df.copy(), lamb=lamb)
        print(df_processed.head())
        if PLOT_HP_RESULTS:
            plot_hp_filter_results(df_processed, commodity_name, lamb=lamb)

    # FFT of the "Trend" column from the last lambda in the sweep.
    fourier_analysis(df_processed, commodity_name)


In [75]:
def fit_sarima_garch(train_data, seasonal_periods, garch_order=(1, 1)):
    """
    Fit a seasonal ARIMA model for the mean and a GARCH model for its residual volatility.

    ``pm.auto_arima`` selects the (p, d, q)(P, D, Q, m) orders with a stepwise
    search. The chosen orders are then re-estimated with statsmodels' SARIMAX,
    and a GARCH model (``arch_model``) is fitted to the SARIMAX residuals.

    Parameters
    ----------
    train_data : pandas.Series
        Training series, passed to both ``auto_arima`` and ``SARIMAX``.
    seasonal_periods : int
        Seasonal period ``m`` for ``auto_arima``.
    garch_order : tuple of int, default (1, 1)
        ``(p, q)`` forwarded to ``arch_model`` as its ``p`` and ``q`` lag orders.

    Returns
    -------
    tuple
        ``(sarima_results, garch_fit)``: the fitted SARIMAX results (which
        provide ``get_forecast``) and the fitted ``arch`` results.

    Raises
    ------
    ValueError
        If no residuals remain after discarding the differencing burn-in.
    """
    # Stepwise order search; trace=True prints every candidate model.
    model = pm.auto_arima(
        train_data,
        seasonal=True,
        m=seasonal_periods,
        stepwise=True,
        suppress_warnings=True,
        trace=True
    )
    sarima_order = model.order
    sarima_seasonal_order = model.seasonal_order
    print(f"Fitted SARIMA order: {sarima_order}, seasonal order: {sarima_seasonal_order}")

    # Re-estimate the selected orders with statsmodels to get a results object
    # that supports get_forecast().
    # NOTE(review): auto_arima may have selected an intercept; this refit does not
    # pass trend='c', so it may not be exactly the selected model -- confirm.
    sarima_model = SARIMAX(train_data, order=sarima_order, seasonal_order=sarima_seasonal_order)
    sarima_results = sarima_model.fit(disp=0)

    # With d regular and D seasonal differences (period m), the first d + D*m
    # one-step residuals come from the diffuse state initialisation rather than
    # from genuine forecast errors (for d=1 the first one is roughly the first
    # observation itself). Skip them so they do not dominate the GARCH fit.
    d = sarima_order[1]
    _, seasonal_d, _, m = sarima_seasonal_order
    burn_in = d + seasonal_d * m
    residuals = sarima_results.resid.iloc[burn_in:].dropna()
    if residuals.empty:
        raise ValueError(
            f"No SARIMA residuals left after dropping {burn_in} burn-in observations; "
            "the training series is too short for the selected orders."
        )

    garch = arch_model(residuals, vol='Garch', p=garch_order[0], q=garch_order[1])
    garch_fit = garch.fit(update_freq=5, disp='off')

    return sarima_results, garch_fit

def forecast_sarima_garch(sarima_model, garch_model, steps):
    """
    Forecast the mean with SARIMA and the conditional volatility with GARCH.

    Parameters
    ----------
    sarima_model
        Fitted statsmodels results exposing ``get_forecast(steps=...)``.
    garch_model
        Fitted ``arch`` results exposing ``forecast(horizon=...)``.
    steps : int
        Forecast horizon.

    Returns
    -------
    mean_forecast : pandas.Series
        SARIMA point forecasts for horizons 1..steps.
    volatility : pandas.Series
        GARCH conditional standard deviation for horizons 1..steps. It is
        re-indexed onto ``mean_forecast.index`` so that
        ``mean_forecast +/- k * volatility`` lines up element-wise.
    """
    sarima_forecast = sarima_model.get_forecast(steps=steps)
    mean_forecast = sarima_forecast.predicted_mean

    garch_forecast = garch_model.forecast(horizon=steps)
    # The last row holds the variance path forecast from the end of the sample.
    # arch labels its columns by horizon, not by date. Keeping those labels would
    # make pandas align mean +/- volatility on disjoint indexes and produce
    # all-NaN bands, so reuse the mean forecast's index instead.
    last_variance = garch_forecast.variance.iloc[-1].to_numpy()
    volatility = pd.Series(np.sqrt(last_variance), index=mean_forecast.index)

    return mean_forecast, volatility

def _load_price_history(dataset_path, start_date="2020-01-01"):
    """Read a commodity CSV, keep rows dated on/after ``start_date``, and index it by sorted "Date"."""
    raw = pd.read_csv(dataset_path, parse_dates=["Date"])
    # Filter and sort in one chain (no in-place sort on a filtered slice).
    return (
        raw.loc[raw["Date"] >= pd.to_datetime(start_date)]
        .sort_values("Date")
        .set_index("Date")
    )


def _select_seasonal_period(dominant_periods, train):
    """Choose the seasonal period m: the strongest FFT peak if any, else the periodogram maximum of ``train``."""
    # `dominant_periods` is an array; `x is None or []` never caught the empty
    # case (an empty list is falsy), so an empty result crashed on [0] below.
    if dominant_periods is None or len(dominant_periods) == 0:
        freqs, pxx = periodogram(train, detrend='linear')
        dominant_index = np.argmax(pxx[1:]) + 1  # ignore zero frequency
        dominant_freq = freqs[dominant_index]
        dominant_period = int(1 / dominant_freq)
        print(f"Using dominant seasonal period determined from periodogram: {dominant_period} days\n")
    else:
        dominant_period = int(round(dominant_periods[0]))
        print(f"Using dominant seasonal period from Fourier analysis: {dominant_period} days\n")
    # NOTE(review): a very large m makes the seasonal auto_arima search slow and
    # memory-hungry; consider capping it if the trend spectrum peaks at long periods.
    return dominant_period


def _forecast_dates(train, test, steps):
    """Dates for forecast steps 1..steps: the next observed (test) dates when available, else calendar days after train."""
    if len(test) >= steps:
        # Step k forecasts the k-th observation after the training window, i.e. test.index[k-1].
        return test.index[:steps]
    return pd.date_range(start=train.index[-1], periods=steps + 1, freq='D')[1:]


def _plot_sarima_garch_results(commodity_name, train, test, predictions, garch_model):
    """Plot training/actual data with the forecast and its 95% band, then the GARCH diagnostics."""
    plt.figure(figsize=(14, 7))
    plt.plot(train.index, train, label='Training Data')
    plt.plot(test.index, test, label='Actual Values')
    plt.plot(predictions.index, predictions['Mean'], label='SARIMA-GARCH Forecast')
    plt.fill_between(predictions.index, predictions['Lower'], predictions['Upper'], color='gray', alpha=0.2)
    plt.title(f'{commodity_name} Price Forecast with Volatility Bands')
    plt.legend()
    plt.show()

    garch_model.plot(annualize='D')
    plt.suptitle(f'GARCH Model Diagnostics - {commodity_name}')
    plt.show()


def analyze_commodity(dataset_path, forecast_steps=30, start_date="2020-01-01"):
    """
    Run the full pipeline for one commodity CSV and plot the forecast.

    Steps: load the data from ``start_date`` on, HP-filter the "Close" column
    (via ``advanced_detrend``), ADF-test the trend, split 80/20 into train/test,
    pick a seasonal period (FFT, falling back to a periodogram), fit
    SARIMA-GARCH on the training trend, and plot a ``forecast_steps``-step
    forecast with +/-1.96 sigma GARCH bands.

    Parameters
    ----------
    dataset_path : str
        Path to a CSV with a "Date" column and the columns ``advanced_detrend`` reads.
    forecast_steps : int, default 30
        Number of steps to forecast past the end of the training window.
    start_date : str, default "2020-01-01"
        Rows before this date are dropped.
    """
    commodity_name = dataset_path.split("/")[-1].split(".")[0]
    df = _load_price_history(dataset_path, start_date=start_date)

    # HP filter. NOTE(review): the filter is two-sided and runs on the full
    # sample, so trend values near the train/test split use test-period prices
    # (look-ahead). Filter the training window alone if the evaluation must be
    # strictly out-of-sample.
    df_processed, trend = advanced_detrend(df.copy())
    data = df_processed["Trend"]

    print(f"Performing ADF Test on the HP-filter trend of the {commodity_name} series:")
    is_stationary = check_stationarity(trend)
    if not is_stationary:
        print("Data is non-stationary. auto_arima will apply differencing as needed to achieve stationarity.\n")
    else:
        print("Data is already stationary.\n")

    # Chronological 80/20 split (positional).
    train_size = int(len(data) * 0.8)
    train = data.iloc[:train_size]
    test = data.iloc[train_size:]

    # Fourier analysis of the "Trend" column.
    dominant_periods, _ = fourier_analysis(df_processed, commodity_name)
    dominant_period = _select_seasonal_period(dominant_periods, train)

    sarima_model, garch_model = fit_sarima_garch(train, seasonal_periods=dominant_period)
    print("Finished fitting SARIMA-GARCH models.\n")

    mean_forecast, volatility = forecast_sarima_garch(sarima_model, garch_model, steps=forecast_steps)

    # Use raw arrays: the forecast and volatility Series carry their own indexes
    # (positional or horizon labels), and building a DataFrame from them with a
    # date index would re-align them into NaNs.
    mean_values = np.asarray(mean_forecast, dtype=float)
    sigma = np.asarray(volatility, dtype=float)
    predictions = pd.DataFrame({
        'Mean': mean_values,
        'Lower': mean_values - 1.96 * sigma,
        'Upper': mean_values + 1.96 * sigma
    }, index=_forecast_dates(train, test, forecast_steps))

    _plot_sarima_garch_results(commodity_name, train, test, predictions, garch_model)

In [None]:
datasets = ["../data/CORN-Prices.csv", "../data/GLD-Prices.csv", "../data/SLV-Prices.csv","../data/WEAT-Prices.csv"]
for dataset in datasets:
    # Skip missing files (as the loader loop earlier in the notebook does)
    # instead of aborting the whole batch with FileNotFoundError.
    if not os.path.exists(dataset):
        print(f"Error: Dataset file not found: {dataset}")
        continue
    analyze_commodity(dataset, forecast_steps=30)