# **Forecasting**
---
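Run every configured forecasting model — the point forecasters (STGCN, ARIMA) and the probabilistic LightGBM model — as a rolling-origin backtest over the test period. Report the per-horizon MAE/RMSE of the farm-level (summed over turbines) power forecast.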

**Load packages and modules**

In [1]:
import logging
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error
from tqdm.notebook import tqdm
import os
from IPython.display import display, clear_output

from config import CONFIG
from data_handler import create_tabular_features_for_window, load_or_create_zarr_cache
from models import (ARIMAModelWrapper, NNModelWrapper, DirectTreeWrapper)
from utils import build_adjacency_from_coords, normalize_adj_torch

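**Forecasting pipeline**

For each model, the pipeline does the following:
- Retrain on the available history every `retrain_every_hours` steps.
- Forecast `horizon` steps ahead from each non-overlapping test origin.
- Aggregate the forecasts to farm level and score each horizon step separately.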

In [2]:
# Extra history (in time steps) passed to create_tabular_features_for_window beyond
# `lookback` when building the tree model's test row. Presumably it covers the
# longest lag/rolling feature that helper computes — TODO confirm against data_handler.
_TREE_FEATURE_BUFFER = 72


def _required_history(wrapper_name, data_conf):
    """Return how many steps of history must precede the first forecast origin for a wrapper.

    A forecast origin earlier than this would produce a negative slice start, which numpy
    wraps around (yielding an empty window) instead of failing.
    """
    if "TreeWrapper" in wrapper_name:
        return data_conf["lookback"] + _TREE_FEATURE_BUFFER
    if "ARIMA" in wrapper_name:
        return 0
    return data_conf["lookback"]  # NNModelWrapper input window


def _retrain_model(model_instance, model_config, data_conf, all_data_seq, timestamps,
                   train_end_idx, A_hat, horizon, target_idx):
    """Refit `model_instance` on history strictly before `train_end_idx`.

    The training window is [max(0, train_end_idx - rolling_window_size), train_end_idx).
    Without a configured `rolling_window_size` it starts at index 0 (expanding window).
    """
    wrapper_name = model_config["wrapper"]
    rolling_window = model_config["training"].get("rolling_window_size", len(timestamps))
    train_start_idx = max(0, train_end_idx - rolling_window)

    if "TreeWrapper" in wrapper_name:
        X_train, y_train = create_tabular_features_for_window(
            all_data_seq[train_start_idx:train_end_idx],
            timestamps[train_start_idx:train_end_idx],
            horizon, target_idx)
        model_instance.train(X_train, y_train)
    elif "ARIMA" in wrapper_name:
        # ARIMA is fit on the farm total: the target channel summed over turbines.
        farm_power_train = all_data_seq[train_start_idx:train_end_idx, :, target_idx].sum(axis=1)
        model_instance.train(farm_power_train)
    else:  # NNModelWrapper
        # Candidate sample start indices stop lookback + horizon steps before train_end_idx,
        # presumably so each input+target window stays inside the training span — confirm in NNModelWrapper.
        last_start = max(train_start_idx, train_end_idx - data_conf["lookback"] - horizon)
        train_indices = list(range(train_start_idx, last_start))
        model_instance.train(all_data_seq, train_indices=train_indices, A_hat=A_hat)


def _predict_window(model_instance, wrapper_name, data_conf, all_data_seq, timestamps,
                    forecast_start_idx, A_hat, horizon, target_idx):
    """Forecast the `horizon` steps starting at `forecast_start_idx`; inputs end before that index."""
    if "TreeWrapper" in wrapper_name:
        feature_gen_start = forecast_start_idx - data_conf["lookback"] - _TREE_FEATURE_BUFFER
        X_test, _ = create_tabular_features_for_window(
            all_data_seq[feature_gen_start:forecast_start_idx],
            timestamps[feature_gen_start:forecast_start_idx],
            horizon, target_idx)
        # The most recent feature row is used as the forecast row.
        # NOTE(review): if the helper drops rows whose horizon targets fall outside the window,
        # this row is older than forecast_start_idx - 1 — confirm.
        return model_instance.predict(X_test.iloc[[-1]])
    if "ARIMA" in wrapper_name:
        # NOTE(review): predict() receives no new observations, so between retrains the ARIMA
        # forecast cannot condition on data observed after its training window unless the
        # wrapper tracks that internally — confirm.
        return model_instance.predict()
    # NNModelWrapper: the `lookback` steps immediately before the origin.
    lookback_start = forecast_start_idx - data_conf["lookback"]
    return model_instance.predict(all_data_seq[lookback_start:forecast_start_idx], A_hat=A_hat)


def _horizon_metrics(farm_trues_np, farm_preds_np, horizon):
    """Compute MAE and RMSE for each horizon step.

    Returns a dict keyed "h{k}_mae" / "h{k}_rmse" for k = 1..horizon, in that order.
    """
    metrics = {}
    for h in range(horizon):
        y_true, y_pred = farm_trues_np[:, h], farm_preds_np[:, h]
        metrics[f"h{h+1}_mae"] = mean_absolute_error(y_true, y_pred)
        metrics[f"h{h+1}_rmse"] = np.sqrt(mean_squared_error(y_true, y_pred))
    return metrics


def _assemble_forecast_frame(forecast_timestamps, final_trues, all_model_preds,
                             probabilistic_preds, models_conf):
    """Combine actuals, point forecasts and probabilistic bounds into one time-indexed frame.

    Every array is flattened in forecast-loop order. `forecast_timestamps` holds exactly one
    timestamp per flattened value in that same order, so the columns align by position. A
    column whose length differs raises instead of being silently truncated.

    Raises:
        ValueError: if any column's length differs from the number of forecast timestamps.
    """
    index = pd.DatetimeIndex(forecast_timestamps)
    columns = {}

    def _add(name, values):
        flat = np.asarray(values).flatten()
        if len(flat) != len(index):
            raise ValueError(f"Column {name!r} has {len(flat)} values but there are "
                             f"{len(index)} forecast timestamps; predictions are misaligned.")
        columns[name] = flat

    _add('actual', final_trues)
    for model_name, preds in all_model_preds.items():
        is_probabilistic = models_conf[model_name].get('probabilistic', False)
        _add(f"{model_name}_median" if is_probabilistic else model_name, preds)
    for model_name, bounds in probabilistic_preds.items():
        if not bounds['lower']:
            # predict() never returned 2-D quantile output for this model.
            logging.warning(f"{model_name}: no quantile forecasts were produced; skipping bound columns.")
            continue
        _add(f"{model_name}_lower", bounds['lower'])
        _add(f"{model_name}_upper", bounds['upper'])

    combined = pd.DataFrame(columns, index=index)
    # Keep the first forecast for any repeated timestamp, then order chronologically.
    return combined[~combined.index.duplicated(keep='first')].sort_index()


def run_forecaster(conf=None):
    """Main forecasting benchmark runner.

    For every model in conf["models"], the runner walks the test period in non-overlapping
    windows of `horizon` steps, starting at conf["evaluation"]["test_start_date"]. The model
    is retrained every conf["evaluation"]["retrain_every_hours"] index steps on data strictly
    before the forecast origin. The farm-level target (summed over turbines) is then scored
    per horizon step.

    Args:
        conf: configuration dict. Defaults to the imported CONFIG. The num_workers clamp is
            written back into conf['run'], so the shared CONFIG is modified when conf is None.

    Returns:
        (summary_df, farm_forecasts_df):
            summary_df: per-horizon MAE/RMSE, one row per model.
            farm_forecasts_df: time-indexed frame with actuals, point forecasts and, for
                probabilistic models, lower/upper bounds.

    Side effects:
        Configures logging, renders interim and final summaries in the notebook, and writes
        forecasts.pkl and forecasts.csv to the working directory.

    Raises:
        ValueError: if conf["models"] is empty, if no complete test window exists, or if the
            test period starts before a model has enough input history.
        KeyError: if a model's "wrapper" is not a name defined in this namespace.
    """
    if conf is None:
        conf = CONFIG
    logging.basicConfig(level=conf['run']['log_level'],
                        format='%(asctime)s - %(levelname)s - %(message)s')

    max_workers = os.cpu_count() or 1
    if conf['run'].get('num_workers', 0) > max_workers:
        logging.warning(f"num_workers ({conf['run']['num_workers']}) > available CPUs ({max_workers}). Setting to {max_workers}.")
        conf['run']['num_workers'] = max_workers

    # Fail fast, before the (potentially expensive) data load.
    if not conf["models"]:
        raise ValueError("conf['models'] is empty: there is nothing to forecast.")

    all_data_seq, timestamps = load_or_create_zarr_cache(conf["data"]["source_path"], conf["data"]["zarr_path"])

    device = conf["run"]["device"]
    logging.info(f"Running on device: {device}")
    torch.manual_seed(conf['run']['seed'])
    np.random.seed(conf['run']['seed'])

    # Turbine graph used by the NN wrapper.
    # NOTE(review): turbine order comes from sorting the coordinate keys. It must match the
    # turbine axis of all_data_seq (a lexicographic sort of ids like 'T10' < 'T2' would not) — confirm.
    n_turbines = all_data_seq.shape[1]
    coords = np.array([v for _, v in sorted(conf["graph"]["coords"].items())])
    A = build_adjacency_from_coords(coords, k=conf["graph"]["k_neighbors"])
    A_hat = normalize_adj_torch(A)

    horizon = conf["data"]["horizon"]
    target_idx = conf['data']['target_col_idx']
    retrain_every = conf["evaluation"]["retrain_every_hours"]  # compared against index steps
    test_start_date = pd.to_datetime(conf["evaluation"]["test_start_date"], utc=True)
    test_start_idx = timestamps.get_loc(test_start_date)

    # Non-overlapping forecast origins, identical for every model.
    # NOTE(review): the exclusive stop skips the final origin len(timestamps) - horizon,
    # even though its ground truth is fully available.
    test_range = range(test_start_idx, len(timestamps) - horizon, horizon)
    if len(test_range) == 0:
        raise ValueError(f"No complete {horizon}-step forecast window after {test_start_date}.")

    results = {}
    all_model_preds = {}
    probabilistic_preds = {}
    final_trues = None
    final_timestamps = None

    for model_name, model_config in conf["models"].items():
        logging.info("\n" + "="*50 + f"\nRunning model: {model_name.upper()}\n" + "="*50)
        wrapper_name = model_config["wrapper"]
        is_probabilistic = model_config.get('probabilistic')

        required = _required_history(wrapper_name, conf["data"])
        if test_start_idx < required:
            raise ValueError(f"{model_name}: the test period starts at index {test_start_idx}, but "
                             f"{required} steps of history are needed for its input window.")

        all_preds, all_trues, forecast_timestamps = [], [], []
        if is_probabilistic:
            probabilistic_preds[model_name] = {'lower': [], 'upper': []}

        wrapper_class = globals().get(wrapper_name)
        if wrapper_class is None:
            raise KeyError(f"Unknown model wrapper {wrapper_name!r} for model {model_name!r}; "
                           f"import it into the notebook first.")
        model_instance = wrapper_class(data_params=conf["data"], run_params=conf["run"], n_turbines=n_turbines, device=device, **model_config)

        last_retrain_idx = -np.inf
        for forecast_start_idx in tqdm(test_range, desc=f"Forecasting with {model_name}"):
            if forecast_start_idx >= last_retrain_idx + retrain_every:
                logging.info(f"\nRetraining {model_name} at Timestamp: {timestamps[forecast_start_idx]}")
                last_retrain_idx = forecast_start_idx
                _retrain_model(model_instance, model_config, conf["data"], all_data_seq, timestamps,
                               forecast_start_idx, A_hat, horizon, target_idx)

            pred = _predict_window(model_instance, wrapper_name, conf["data"], all_data_seq, timestamps,
                                   forecast_start_idx, A_hat, horizon, target_idx)

            # Ground truth: farm total over the forecast window.
            true_slice = slice(forecast_start_idx, forecast_start_idx + horizon)
            all_trues.append(all_data_seq[true_slice, :, target_idx].sum(axis=1))
            forecast_timestamps.extend(timestamps[true_slice])

            if is_probabilistic and pred.ndim == 2:
                # Assumes quantile columns ordered (lower, median, upper) — TODO confirm in the wrapper.
                all_preds.append(pred[:, 1])
                probabilistic_preds[model_name]['lower'].append(pred[:, 0])
                probabilistic_preds[model_name]['upper'].append(pred[:, 2])
            elif pred.ndim == 2:  # per-turbine NN forecasts -> farm total
                all_preds.append(pred.sum(axis=1))
            else:
                all_preds.append(pred)

        farm_preds_np, farm_trues_np = np.array(all_preds), np.array(all_trues)
        all_model_preds[model_name] = farm_preds_np
        if final_trues is None:
            # Actuals and their timestamps come from the same model run, so they stay aligned.
            final_trues = farm_trues_np
            final_timestamps = forecast_timestamps
        results[model_name] = _horizon_metrics(farm_trues_np, farm_preds_np, horizon)

        clear_output(wait=True)
        summary_df = pd.DataFrame(results).T
        print("--- INTERIM RESULTS SUMMARY ---")
        display(summary_df.style.format("{:.3f}"))

    # save output
    clear_output(wait=True)
    farm_forecasts_df = _assemble_forecast_frame(final_timestamps, final_trues, all_model_preds,
                                                 probabilistic_preds, conf['models'])

    logging.info("\n" + "="*50 + "\nFINAL RESULTS SUMMARY\n" + "="*50)
    summary_df = pd.DataFrame(results).T
    print(summary_df.to_string(float_format="%.3f"))

    farm_forecasts_df.to_pickle("forecasts.pkl")
    summary_df.to_csv("forecasts.csv")
    logging.info("\nForecasts saved to 'forecasts.pkl'")
    logging.info("Results summary saved to 'forecasts.csv'")

    return summary_df, farm_forecasts_df

**Run**

In [None]:
# Inside a Jupyter/IPython kernel __name__ is '__main__', so this cell always runs the benchmark
# and exposes its two results as notebook-level names.
if '__main__' == __name__:
    (summary_df,
     farm_forecasts_df) = run_forecaster()

2025-09-06 18:39:45,480 - INFO - Running on device: cuda
2025-09-06 18:39:45,484 - INFO - 
Running model: STGCN


Forecasting with stgcn:   0%|          | 0/1460 [00:00<?, ?it/s]

2025-09-06 18:39:45,492 - INFO - 
Retraining stgcn at Timestamp: 2023-01-01 00:00:00+00:00
2025-09-06 18:39:45,492 - INFO - Training NN model...
2025-09-06 18:39:48,064 - INFO - Epoch 1/75, Train Loss: 0.18188, Val Loss: 0.03790, LR: 0.000100
2025-09-06 18:39:49,400 - INFO - Epoch 2/75, Train Loss: 0.08188, Val Loss: 0.03711, LR: 0.000100
2025-09-06 18:39:50,706 - INFO - Epoch 3/75, Train Loss: 0.05949, Val Loss: 0.03103, LR: 0.000100
2025-09-06 18:39:52,032 - INFO - Epoch 4/75, Train Loss: 0.04586, Val Loss: 0.02814, LR: 0.000100
2025-09-06 18:39:53,334 - INFO - Epoch 5/75, Train Loss: 0.03861, Val Loss: 0.02608, LR: 0.000100
2025-09-06 18:39:54,650 - INFO - Epoch 6/75, Train Loss: 0.03480, Val Loss: 0.02565, LR: 0.000100
2025-09-06 18:39:55,972 - INFO - Epoch 7/75, Train Loss: 0.03195, Val Loss: 0.02537, LR: 0.000100
2025-09-06 18:39:57,289 - INFO - Epoch 8/75, Train Loss: 0.03041, Val Loss: 0.02488, LR: 0.000100
2025-09-06 18:39:58,606 - INFO - Epoch 9/75, Train Loss: 0.02907, Val L

In [4]:
# Display the combined frame returned by run_forecaster(): actuals, point forecasts and
# probabilistic bounds per timestamp (pandas truncates long frames to head/tail).
# NOTE(review): this relies on the run cell above having completed in this kernel.
farm_forecasts_df

Unnamed: 0,actual,stgcn,lgbm_direct_median,arima,lgbm_direct_lower,lgbm_direct_upper
2023-01-01 00:00:00+00:00,7851.432129,8464.554688,6626.377263,7660.308668,5668.209283,7222.207512
2023-01-01 01:00:00+00:00,8088.687012,8273.955078,6579.910121,7489.318373,5178.847177,7476.828777
2023-01-01 02:00:00+00:00,8076.753418,8328.116211,6553.983274,7434.695952,5550.499707,7496.508944
2023-01-01 03:00:00+00:00,8082.624023,8480.824219,7326.990246,7466.632945,5910.244411,7784.171767
2023-01-01 04:00:00+00:00,7895.537109,8408.533203,6981.456610,7491.369261,5708.713576,7901.231101
...,...,...,...,...,...,...
2023-12-31 19:00:00+00:00,8356.484375,6130.388672,5947.659419,7359.358250,3047.488009,9297.843791
2023-12-31 20:00:00+00:00,9377.143555,6283.792480,5988.003615,7122.937347,2630.668650,9502.796150
2023-12-31 21:00:00+00:00,9948.425781,6494.790527,6196.374619,6916.658855,1983.337696,9584.001618
2023-12-31 22:00:00+00:00,9766.027344,6736.031738,6535.992666,7308.262150,1653.474384,9924.547318
