In [1]:
from datasetsforecast.m4 import M4

# Take the first item returned by the M4 "Monthly" loader and order its rows
# by series id, then by time step. The sort returns a new frame rather than
# mutating in place, so re-running the cell is safe.
df = M4().load("../data", group="Monthly")[0].sort_values(by=["unique_id", "ds"])
df

Unnamed: 0,unique_id,ds,y
0,M1,1,8000.0
1,M1,2,8350.0
2,M1,3,8570.0
3,M1,4,7700.0
4,M1,5,7080.0
...,...,...,...
11246406,M9999,83,4200.0
11246407,M9999,84,4300.0
11246408,M9999,85,3800.0
11246409,M9999,86,4400.0


In [2]:
# Distribution of series lengths (rows per unique_id). `size()` is the built-in
# group-count; unlike `apply(len)` it never hands the group frames to a Python
# callable, so it is faster and does not trigger the groupby-apply warning.
df.groupby("unique_id").size().describe()

  df.groupby('unique_id').apply(len).describe()


count    48000.000000
mean       234.300229
std        137.406295
min         60.000000
25%        100.000000
50%        220.000000
75%        324.000000
max       2812.000000
dtype: float64

In [3]:
from ts.preprocess.dataloader import UnivariateTSDataModule

In [4]:
# Window geometry and loader settings; `input_size` and `horizon` are reused
# when the model is constructed further down.
input_size, horizon = 48, 12
batch_size = 1024
num_workers = 20

ds = UnivariateTSDataModule(
    df=df,
    # --- windowing ---
    input_size=input_size,
    horizon=horizon,
    step_size=1,
    # --- loading ---
    batch_size=batch_size,
    num_workers=num_workers,
    # --- splitting (remaining 0.15 is presumably the test share — TODO confirm) ---
    split_type="vertical",
    train_split=0.7,
    val_split=0.15,
    # --- scaling ---
    normalize=True,
    scaler_type="minmax",
)

In [5]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from ts.model.nbeats import NBeatsG

In [None]:
# Train N-BEATS-G on the M4 monthly windows. Early stopping on `val_loss`
# (patience 10) may end the run before the epoch cap is reached.
trainer = pl.Trainer(
    max_epochs=200,  # Upper bound only; EarlyStopping can stop sooner
    accelerator="auto",
    # Run name matches the dataset actually loaded above (M4 Monthly), so the
    # TensorBoard runs are not filed under the wrong benchmark.
    logger=TensorBoardLogger("logs", name="nbeatsg_m4_monthly"),
    callbacks=[EarlyStopping("val_loss", patience=10, verbose=False)],
)
model = NBeatsG(input_size=input_size, horizon=horizon)
trainer.fit(model, ds)

INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA GeForce RTX 4070 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:ts.preprocess.dataloader:Train windows: 5891975, Val windows: 1241277, Test windows: 1281159
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_D

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

In [None]:
# Evaluate the model fitted in the previous cell using the same data module
# (`ds`) that was passed to `trainer.fit`.
trainer.test(model, ds)

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots

# Function to forecast and plot in a 3x2 grid without legend


def forecast_and_plot_grid(trainer, model, data_module, num_series=6):
    """Forecast the last `horizon` points of randomly chosen series and plot them.

    For each sampled series the window of `input_size` values that immediately
    precedes the final `horizon` values is fed to the model, and the prediction
    (red, dashed) is drawn against the held-back actual values (blue) in a
    two-column subplot grid.

    Args:
        trainer: Unused; kept so existing calls keep working.
        model: Object providing `.eval()`, `.device`, and a call that accepts a
            float32 tensor of shape (1, input_size).
        data_module: Object exposing `df` (with "unique_id", "ds" and "y"
            columns), `input_size`, `horizon`, `normalize` and, when
            `normalize` is true, a `scalers` mapping looked up by series id.
        num_series: Number of series to sample (capped at the number of
            distinct series). The grid has at least 3 rows and grows when more
            than 6 series are requested.

    Returns:
        None. The figure is displayed with `fig.show()`.
    """
    # Ensure model is in evaluation mode
    model.eval()

    input_size = data_module.input_size
    horizon = data_module.horizon

    # Randomly pick the series to display (fewer if not enough are available)
    unique_ids = data_module.df["unique_id"].unique()
    selected_ids = np.random.choice(
        unique_ids, size=min(num_series, len(unique_ids)), replace=False
    )

    selected_rows = data_module.df[data_module.df["unique_id"].isin(selected_ids)]
    grouped = selected_rows.groupby("unique_id")

    # Size the grid from the number of series so that more than six of them do
    # not address a row that does not exist. Three rows is the minimum, which
    # keeps the default layout unchanged.
    n_cols = 2
    n_rows = max(3, (len(selected_ids) + n_cols - 1) // n_cols)
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"Series: {uid}" for uid in selected_ids],
        # Plotly requires vertical_spacing <= 1 / (rows - 1); shrink it for
        # taller grids while keeping 0.15 for the default three rows.
        vertical_spacing=min(0.15, 0.5 / n_rows),
        horizontal_spacing=0.1,
    )

    for i, unique_id in enumerate(selected_ids):
        group = grouped.get_group(unique_id)
        series = group["y"].values

        if len(series) < input_size + horizon:
            print(f"Skipping {unique_id}: too short for forecasting")
            continue

        # Only look up a scaler when normalisation is enabled; a missing entry
        # means the series is used as-is.
        scaler = data_module.scalers.get(unique_id) if data_module.normalize else None
        if scaler is not None:
            series_normalized = scaler.transform(series.reshape(-1, 1)).flatten()
        else:
            series_normalized = series

        # Last input window: the `input_size` values just before the final horizon
        last_input = series_normalized[-input_size - horizon : -horizon]
        x = torch.tensor(last_input, dtype=torch.float32).unsqueeze(0).to(model.device)

        with torch.no_grad():
            y_hat = model(x).cpu().numpy().flatten()

        # Bring the prediction back to the original scale; actuals are already raw
        if scaler is not None:
            y_hat_denorm = scaler.inverse_transform(y_hat.reshape(-1, 1)).flatten()
        else:
            y_hat_denorm = y_hat
        y_true_denorm = series[-horizon:]

        # Use the series' own `ds` values for the x-axis. Synthesising a monthly
        # date range from `ds` is only meaningful for real timestamps; an
        # integer step index would be read as an epoch offset and yield
        # fabricated 1970 dates.
        time_indices = group["ds"].iloc[-horizon:].to_numpy()

        # Subplot position (Plotly rows/cols are 1-based)
        row = (i // n_cols) + 1
        col = (i % n_cols) + 1

        # Actual values
        fig.add_trace(
            go.Scatter(
                x=time_indices,
                y=y_true_denorm,
                mode="lines+markers",
                line=dict(color="blue"),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        # Predicted values
        fig.add_trace(
            go.Scatter(
                x=time_indices,
                y=y_hat_denorm,
                mode="lines+markers",
                line=dict(color="red", dash="dash"),
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    fig.update_layout(
        height=300 * n_rows,  # 900 for the default three rows
        width=800,
        title_text=f"N-BEATS-G Forecasting Results ({n_rows}x{n_cols} Grid)",
        showlegend=False,  # Remove legend entirely
    )
    fig.update_yaxes(title_text="Value")
    fig.update_xaxes(title_text="Date")

    fig.show()


# Example usage (in this notebook the data module is bound to `ds`)
# forecast_and_plot_grid(trainer, model, ds, num_series=6)