In [1]:
# | default_exp train.nbeats

In [2]:
import torch
import wandb
from pytorch_lightning.profilers import PyTorchProfiler

In [3]:
import pandas as pd

# from datasetsforecast.m3 import M3
# from datasetsforecast.m4 import M4
# from datasetsforecast.m5 import M5

# df = pd.concat(
#     [
#         M3().load("../data", group="Monthly")[0],
#         M4().load("../data", group="Monthly")[0],
#         M4().load("../data", group="Weekly")[0],
#         M4().load("../data", group="Daily")[0],
#         M5().load("../data")[0],
#     ]
# )

# # Ensure ds is a datetime object
# df["ds"] = pd.to_datetime(df["ds"], errors="coerce")

# # Sort values
# df.sort_values(["unique_id", "ds"], inplace=True)

# # Convert ds to an integer based on sorted order within each unique_id
# df["ds"] = df.groupby("unique_id")["ds"].rank(method="dense").astype(int)

# # Save as parquet
# df.to_parquet("mid-range-forecast-data-M3-4-5.parquet", index=False)

In [4]:
df = pd.read_parquet("mid-range-forecast-data-M3-4-5.parquet")

In [5]:
df.unique_id.nunique()

83076

In [6]:
df.groupby("unique_id").apply(len).describe()

  df.groupby("unique_id").apply(len).describe()


count    83076.000000
mean       836.093794
std        880.332992
min         60.000000
25%        193.000000
50%        335.000000
75%       1557.000000
max       9933.000000
dtype: float64

In [7]:
from ts.models.nbeats import NBeatsG

horizon = 12
input_size = horizon * 5

model = NBeatsG(input_size=input_size, horizon=horizon)

In [8]:
# | export
import torch

In [9]:
# | export
from ts.preprocess.dataloader import UnivariateTSDataModule

batch_size = 512 * 10
num_workers = 24
step_size = 6

ds = UnivariateTSDataModule(
    df=df,
    input_size=input_size,
    horizon=horizon,
    batch_size=batch_size,
    num_workers=num_workers,
    train_split=0.7,
    val_split=0.15,
    normalize=True,
    scaler_type="minmax",
    split_type="vertical",
    step_size=step_size,
    prefetch_factor=4,
)

In [10]:
# | export
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

In [15]:
# | export
# Example trainer setup (without full NBeatsG for brevity)
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(
    project="shortterm-ts-global-forecast",
    name=f"model=NBeatsG.ds=M5",
)
wandb_logger.watch(model, log="all")

profiler = PyTorchProfiler(
    profile_memory=True,  # Track GPU memory
    record_shapes=True,
    with_stack=True,  # Track CPU memory (if supported)
)

trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=200,  # Short run for testing
    accelerator="auto",
    precision="16-mixed",
    gradient_clip_val=1.0,
    # logger=TensorBoardLogger("logs", name="nbeatsg_m5"),
    callbacks=[EarlyStopping("val_smape", patience=10, verbose=False)],
    # profiler=profiler,
    accumulate_grad_batches=4,
    # strategy="ddp_notebook"
)


# trainer.fit(model, ds,)

/home/pranav-pc/projects/ts/.venv/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


ValueError: You can only call `wandb.watch` once per model.  Pass a new instance of the model if you need to call wandb.watch again in your code.

In [12]:
# | export
trainer.test(model, ds)
wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇██████
test_loss,▁
test_mase,▁
test_owa,▁
test_smape,▁
train_loss_epoch,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss_step,█▄▅▁▆▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇█
val_loss,▄█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▁▁
val_mase,▆█▆▃▂▁▁▁▁▁▁▁▁▁▁▁▃▃▂▃▂▂

0,1
epoch,22.0
test_loss,3e-05
test_mase,0.0
test_owa,0.88232
test_smape,0.02866
train_loss_epoch,4e-05
train_loss_step,0.0001
trainer/global_step,23188.0
val_loss,3e-05
val_mase,0.0


In [None]:
trainer.validate(model, ds)

In [None]:
# | export
# trainer.save_checkpoint("SHORT-TERM-FORECAST-MODEL-NBEATSG(60-12).ckpt")

In [30]:
## Inference

In [11]:
from ts.models.nbeats import NBeatsG

# horizon = 12
# input_size = horizon * 5

# model = NBeatsG(input_size=input_size, horizon=horizon)

# Load the model from checkpoint
model = NBeatsG.load_from_checkpoint("SHORT-TERM-FORECAST-MODEL-NBEATSG(60-12).ckpt")

# If needed, load it into a Trainer to resume training or inference
# from pytorch_lightning import Trainer

# trainer = Trainer()
# trainer.validate(model,ds)  # Run validation
# trainer.test(model,ds);  # Run testing

In [16]:
ds.setup()



In [None]:
import pandas as pd
import plotly.express as px
import torch

model.eval()

mse_values = []

for x, y in ds.test_dataloader():
    y_hat = model(x)  # Get predictions
    errors = torch.nn.functional.mse_loss(y_hat, y, reduction="none")  # Compute per-sample MSE
    mse_values.extend(errors.detach().view(-1).cpu().numpy())  # Detach, flatten & move to CPU
    break  # Only process first batch

# Convert to DataFrame
df_error = pd.DataFrame({"MSE": mse_values})

# Create violin plot
fig = px.violin(df_error, y="MSE", box=True, points="all", title="Distribution of MSE")
fig.show()

In [14]:
df_error

Unnamed: 0,MSE
0,1.149312e-04
1,1.626208e-09
2,2.115133e-05
3,4.014260e-06
4,1.002962e-05
...,...
61435,1.559591e-05
61436,4.726880e-06
61437,1.703628e-05
61438,1.761820e-06


In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from sklearn.preprocessing import MinMaxScaler


def forecast_and_plot_grid(model, data_module, num_series=6):
    # Ensure model is in evaluation mode
    model.eval()

    # Get all unique_ids from the DataFrame
    unique_ids = data_module.df["unique_id"].unique()

    # Randomly select series (or use all if fewer available)
    selected_ids = np.random.choice(
        unique_ids, size=min(num_series, len(unique_ids)), replace=False
    )

    # Prepare test data for selected series
    test_data = data_module.df[data_module.df["unique_id"].isin(selected_ids)]
    grouped = test_data.groupby("unique_id")

    # Create figure with 3x2 grid
    fig = make_subplots(
        rows=3,
        cols=2,
        subplot_titles=[f"Series: {uid}" for uid in selected_ids],
        vertical_spacing=0.15,
        horizontal_spacing=0.1,
    )

    # Perform forecasting and plotting
    for i, unique_id in enumerate(selected_ids):
        series_df = grouped.get_group(unique_id)
        series = series_df["y"].values  # Raw series
        series_len = len(series)

        if series_len < data_module.input_size + data_module.horizon:
            print(f"Skipping {unique_id}: too short for forecasting")
            continue

        # MinMax scaling on the entire series
        scaler = MinMaxScaler(feature_range=(0, 1))
        series_scaled = scaler.fit_transform(series.reshape(-1, 1)).flatten()

        # Generate last input window for forecasting
        last_input = series_scaled[
            -data_module.input_size - data_module.horizon : -data_module.horizon
        ]
        x = torch.tensor(last_input, dtype=torch.float32).unsqueeze(0).to(model.device)

        # Forecast
        with torch.no_grad():
            y_hat = model(x).cpu().numpy().flatten()

        # Use MinMax transformed series, y_hat is already in model scale
        full_time_indices = series_df["ds"].values  # Full series timestamps
        forecast_time_indices = full_time_indices[
            -data_module.horizon :
        ]  # Last horizon timestamps

        # Determine row and column (Plotly uses 1-based indexing)
        row = (i // 2) + 1
        col = (i % 2) + 1

        # Plot MinMax transformed actual series
        fig.add_trace(
            go.Scatter(
                x=full_time_indices,
                y=series_scaled,
                mode="lines",
                line=dict(color="blue"),
                name=f"Actual {unique_id}",
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        # Plot predicted values for the last horizon
        fig.add_trace(
            go.Scatter(
                x=forecast_time_indices,
                y=y_hat,  # Directly using model output
                mode="lines",
                line=dict(color="red", dash="dash"),
                name=f"Predicted {unique_id}",
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    # Update layout
    fig.update_layout(
        height=900,
        width=800,
        title_text="Forecasting: MinMax Scaled Series with Predictions (3x2 Grid)",
        showlegend=False,
    )
    fig.update_yaxes(title_text="Scaled Value (0-1)")
    fig.update_xaxes(title_text="Date")

    # Show plot
    fig.show()


# Example usage
forecast_and_plot_grid(model, ds, num_series=6)

In [25]:
# trainer.validate(model,ds)

<ts.preprocess.dataloader.UnivariateTSDataModule at 0x7fde2c9e7320>