In [None]:
"""
Main forecasting benchmark runner.

This script orchestrates the entire pipeline:
1. Loads and preprocesses data (using Zarr for caching).
2. Sets up the graph structure for the GCN model.
3. Iterates through each model defined in the config.
4. Runs a robust, non-overlapping (stacked) forecast evaluation with weekly retraining.
5. Calculates performance metrics (MAE, RMSE) for each hour in the forecast horizon.
6. Saves the final forecasts and a summary of the results.
"""
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error
from tqdm import tqdm

from config import CONFIG
from data_handler import create_tabular_features, load_or_create_zarr_cache
from models import (ARIMABaseline, ARIMAModelWrapper, GCN_GRU, NNModelWrapper,
                    QuantileRegressionWrapper)
from utils import build_adjacency_from_coords, normalize_adj


def _horizon_metrics(farm_trues, farm_preds, horizon):
    """Return per-step MAE/RMSE for stacked ``(n_windows, horizon)`` arrays.

    Keys are ``h{k}_mae`` and ``h{k}_rmse`` for ``k = 1..horizon`` (MAE first
    for each step). If no windows were forecast, every metric is NaN instead
    of raising on the empty array.
    """
    metrics = {}
    for h in range(horizon):
        if len(farm_trues) == 0:
            mae = rmse = np.nan
        else:
            mae = mean_absolute_error(farm_trues[:, h], farm_preds[:, h])
            rmse = np.sqrt(mean_squared_error(farm_trues[:, h], farm_preds[:, h]))
        metrics[f"h{h+1}_mae"] = mae
        metrics[f"h{h+1}_rmse"] = rmse
    return metrics


def _assemble_forecast_frame(model_outputs):
    """Align actuals and every model's forecasts on a shared timestamp index.

    Args:
        model_outputs: mapping ``model_name -> dict`` with keys
            ``"timestamps"`` (one timestamp per forecast step, in the same
            order as the flattened arrays), ``"trues"`` and ``"preds"``
            (arrays that flatten to ``len(timestamps)`` values), and
            optionally ``"lower"`` / ``"upper"`` (same shape as ``"preds"``).

    Returns:
        DataFrame with column ``actual``, then one column per model, then all
        ``<model>_lower`` columns, then all ``<model>_upper`` columns. Each
        model is aligned by its *own* timestamps, so windows a model skipped
        appear as NaN instead of shifting later values onto the wrong rows.

    Raises:
        ValueError: if a model's flattened array length differs from its
            number of timestamps (previously this was silently truncated).
    """
    if not model_outputs:
        return pd.DataFrame(columns=["actual"])

    actual_parts, pred_cols, lower_cols, upper_cols = [], [], [], []
    for name, out in model_outputs.items():
        index = pd.DatetimeIndex(out["timestamps"])
        actual_parts.append(pd.Series(np.asarray(out["trues"]).reshape(-1), index=index))
        pred_cols.append(
            pd.Series(np.asarray(out["preds"]).reshape(-1), index=index, name=name)
        )
        if "lower" in out:
            lower_cols.append(
                pd.Series(np.asarray(out["lower"]).reshape(-1), index=index, name=f"{name}_lower")
            )
        if "upper" in out:
            upper_cols.append(
                pd.Series(np.asarray(out["upper"]).reshape(-1), index=index, name=f"{name}_upper")
            )

    # Actuals at a given timestamp are identical across models (same source
    # array), so the union of all models' actuals, de-duplicated, covers every
    # forecast row exactly once.
    actual = pd.concat(actual_parts)
    actual = actual[~actual.index.duplicated(keep="first")].rename("actual")
    return pd.concat([actual, *pred_cols, *lower_cols, *upper_cols], axis=1)


def run_benchmark():
    """Run every model in ``CONFIG["models"]`` through a stacked backtest.

    Starting at ``CONFIG["evaluation"]["test_start_date"]``, the test period is
    cut into consecutive, non-overlapping windows of ``horizon`` steps. Each
    model is retrained whenever at least ``retrain_every_hours`` index steps
    (hours, assuming an hourly timestamp index) have passed since its last
    retrain, then forecasts the next window. Actuals are channel 0 of the
    sequence data summed over turbines; non-quantile predictions are summed
    over their axis 1 (presumably the turbine axis -- confirm in the wrapper).

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]:
            * ``summary_df`` -- one row per model with ``h{k}_mae`` and
              ``h{k}_rmse`` columns for each forecast step ``k``.
            * ``farm_forecasts_df`` -- indexed by forecast timestamp, with an
              ``actual`` column, one column per model, and
              ``<model>_lower`` / ``<model>_upper`` columns (0.05 / 0.95
              quantiles) for ``QuantileRegressionWrapper`` models. Windows a
              model skipped are NaN in that model's columns.

    Raises:
        KeyError: if ``test_start_date`` is not an exact cached timestamp, or a
            configured ``wrapper`` name is not defined in this module.
    """
    # --- 1. Load data and config ---
    conf = CONFIG
    all_data_seq, timestamps = load_or_create_zarr_cache(
        conf["data"]["source_path"], conf["data"]["zarr_path"]
    )
    X_tab, y_tab = create_tabular_features(
        all_data_seq, timestamps, conf["data"]["horizon"]
    )

    device = conf["run"]["device"]
    print(f"Running on device: {device}")

    # --- 2. Set up the experiment ---
    n_turbines = all_data_seq.shape[1]
    # NOTE(review): coordinates are ordered by sorted dict key; this must match
    # the turbine order along axis 1 of all_data_seq -- confirm.
    coords = np.array([v for _, v in sorted(conf["graph"]["coords"].items())])
    A = build_adjacency_from_coords(coords, k=conf["graph"]["k_neighbors"])
    A_hat = normalize_adj(A)
    horizon = conf["data"]["horizon"]
    retrain_every = conf["evaluation"]["retrain_every_hours"]

    test_start_date = pd.to_datetime(conf["evaluation"]["test_start_date"], utc=True)
    test_start_idx = timestamps.get_loc(test_start_date)

    # The last window whose full horizon is observed starts at
    # len(timestamps) - horizon; +1 because range's stop is exclusive.
    test_range = range(test_start_idx, len(timestamps) - horizon + 1, horizon)

    # --- 3. Run models ---
    results = {}
    model_outputs = {}

    for model_name, model_config in conf["models"].items():
        print("\n" + "=" * 50 + f"\nRunning model: {model_name.upper()}\n" + "=" * 50)

        is_tabular = model_config["wrapper"] == "QuantileRegressionWrapper"
        all_preds, all_trues, all_lower, all_upper = [], [], [], []
        forecast_timestamps = []
        last_retrain_idx = -np.inf  # guarantees a retrain on the first window

        # Wrapper classes are resolved by name from this module's namespace.
        wrapper_class = globals()[model_config["wrapper"]]
        model_instance = wrapper_class(
            data_params=conf["data"], n_turbines=n_turbines, device=device, **model_config
        )

        for forecast_start_idx in tqdm(test_range, desc=f"Forecasting with {model_name}"):

            # Retraining: only data strictly before the forecast window is used.
            if forecast_start_idx >= last_retrain_idx + retrain_every:
                # tqdm.write keeps the progress bar intact (print tears it).
                tqdm.write(f"\nRetraining {model_name} at Timestamp: {timestamps[forecast_start_idx]}")
                last_retrain_idx = forecast_start_idx
                train_end_idx = forecast_start_idx

                rolling_window = model_config["training"].get("rolling_window_size", len(timestamps))
                train_start_idx = max(0, train_end_idx - rolling_window)

                if is_tabular:
                    # The feature row at t is used below to forecast
                    # t+1..t+horizon, so rows at t >= train_end - horizon carry
                    # labels from inside the forecast window; exclude them.
                    label_cutoff_idx = max(0, train_end_idx - horizon)
                    train_mask = (X_tab.index >= timestamps[train_start_idx]) & (
                        X_tab.index < timestamps[label_cutoff_idx]
                    )
                    model_instance.train(X_tab[train_mask], y_tab[train_mask])
                else:
                    # Stop early enough that lookback + horizon fits before
                    # train_end; max() keeps the range non-negative in length.
                    lookback = conf["data"]["lookback"]
                    train_indices = list(
                        range(train_start_idx, max(train_start_idx, train_end_idx - lookback - horizon))
                    )
                    model_instance.train(all_data_seq, train_indices=train_indices, A_hat=A_hat)

            # Prediction
            if is_tabular:
                feature_ts = timestamps[forecast_start_idx - 1]
                if feature_ts not in X_tab.index:
                    # No feature row (e.g. dropped for NaNs): skip this window.
                    continue
                pred = model_instance.predict(X_tab.loc[[feature_ts]])
            else:
                lookback = conf["data"]["lookback"]
                seq_window = all_data_seq[forecast_start_idx - lookback:forecast_start_idx]
                pred = model_instance.predict(seq_window, A_hat=A_hat)

            # Store results for this window
            true_slice = slice(forecast_start_idx, forecast_start_idx + horizon)
            all_trues.append(all_data_seq[true_slice, :, 0].sum(axis=1))
            forecast_timestamps.extend(timestamps[true_slice])

            if is_tabular:
                all_preds.append(pred[0.5][0])
                all_lower.append(pred[0.05][0])
                all_upper.append(pred[0.95][0])
            else:
                all_preds.append(pred.sum(axis=1))

        # --- 4. Evaluate and store this model's results ---
        farm_preds_np = np.array(all_preds)
        farm_trues_np = np.array(all_trues)

        output = {
            "timestamps": forecast_timestamps,
            "trues": farm_trues_np,
            "preds": farm_preds_np,
        }
        if all_lower:
            output["lower"] = np.array(all_lower)
            output["upper"] = np.array(all_upper)
        model_outputs[model_name] = output

        results[model_name] = _horizon_metrics(farm_trues_np, farm_preds_np, horizon)

    # --- 5. Assemble the final stacked DataFrame (aligned by timestamp) ---
    farm_forecasts_df = _assemble_forecast_frame(model_outputs)

    print("\n" + "="*50 + "\nFINAL RESULTS SUMMARY\n" + "="*50)
    summary_df = pd.DataFrame(results).T
    print(summary_df.to_string(float_format="%.2f"))

    return summary_df, farm_forecasts_df


if __name__ == '__main__':
    # Runs only when executed as a script (or as a notebook cell, where
    # __name__ is "__main__"); importing the module triggers no benchmark.
    final_results, final_forecasts = run_benchmark()
    output_path = "final_forecasts.pkl"
    final_forecasts.to_pickle(output_path)
    print(f"\nForecasts saved to '{output_path}'")

Loaded sequence data from Zarr cache: data_cache.zarr
Creating tabular features for tree-based models...
Tabular feature creation complete. Shape: (8706, 14)
Running on device: cuda

Running model: GCN_GRU


Forecasting with gcn_gru:   0%|          | 0/855 [00:00<?, ?it/s]


Retraining gcn_gru at Timestamp: 2022-06-01 00:00:00+00:00
Training NN model...


Forecasting with gcn_gru:   0%|          | 1/855 [02:12<31:26:23, 132.53s/it]


Retraining gcn_gru at Timestamp: 2022-06-08 00:00:00+00:00
Training NN model...


In [None]:
# Notebook display cell: shows the forecast DataFrame returned by
# run_benchmark() and bound to `final_forecasts` in the `__main__` cell.
final_forecasts