In [4]:
from ts.models.nbeats import NBeatsG

In [5]:
import numpy as np
import pandas as pd
import plotly.express as px
import torch

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model from checkpoint
model = NBeatsG.load_from_checkpoint("checkpoints/nbeat_m3_run1.ckpt")
model = model.to(device)
model.eval()

df = pd.read_parquet("data/intermediate/m3-monthly_inference_preprocessed.parquet")
X_scaled = df["X"].values

# Step 3: Convert to tensor
X_tensor = torch.tensor(np.stack(X_scaled), dtype=torch.float32, device=device)

# Step 4: Inference
with torch.no_grad():
    yhat_tensor = model(X_tensor)
    yhat_np = yhat_tensor.cpu().numpy()

# Step 5: Build forecast_df
records = []
for i, (uid, row) in enumerate(df.iterrows()):
    y_true = row["y"]
    yhat_scaled = yhat_np.flatten()  # scalers[uid].inverse_transform(yhat_np[i].reshape(-1, 1))
    ds_vals = row["ds"]
    for ds_i, y_i, yhat_i in zip(ds_vals, y_true, yhat_scaled):
        records.append({"unique_id": uid, "ds": ds_i, "y_scaled": y_i, "yhat_scaled": yhat_i})

forecast_df = pd.DataFrame(records)
forecast_df.to_parquet("data/intermediate/m3-monthly_scaled_forecast.parquet")

In [6]:
forecast_df

Unnamed: 0,unique_id,ds,y_scaled,yhat_scaled
0,M1000,1993-03-31,0.939494,0.907305
1,M1000,1993-04-30,0.976422,0.909661
2,M1000,1993-05-31,0.915569,0.906937
3,M1000,1993-06-30,0.934119,0.904528
4,M1000,1993-07-31,0.952323,0.931337
...,...,...,...,...
13291,M999,1993-10-31,0.930218,0.973675
13292,M999,1993-11-30,0.939335,0.946705
13293,M999,1993-12-31,0.895766,0.943014
13294,M999,1994-01-31,0.857894,0.944151


In [33]:
df

Unnamed: 0_level_0,X,y,ds
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
M1000,"[0.31848124, 0.34153956, 0.38557568, 0.3997919...","[0.9394939, 0.9764219, 0.9155687, 0.9341193, 0...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M1001,"[0.29813406, 0.31244454, 0.33096397, 0.3511668...","[0.9497264, 1.0, 0.9810131, 0.9696956, 0.97507...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M1002,"[0.49256065, 0.4962279, 0.5298617, 0.53394794,...","[0.8604359, 0.9062238, 0.9224644, 0.94247705, ...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M1003,"[0.22400206, 0.22979066, 0.24345411, 0.2748198...","[0.9689037, 0.9968367, 1.0, 0.9972406, 0.97300...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M1004,"[0.6991379, 0.675431, 0.6758621, 0.7137931, 0....","[0.64181036, 0.662931, 0.65818965, 0.6590517, ...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
...,...,...,...
M995,"[0.5009964, 0.5217218, 0.53965724, 0.5557991, ...","[0.94021523, 0.9414109, 0.94938225, 0.9475887,...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M996,"[0.31983837, 0.3001667, 0.3182576, 0.34433994,...","[0.97198546, 1.0, 0.98594886, 0.97365415, 0.96...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M997,"[0.13670345, 0.14133143, 0.1593687, 0.14429809...","[0.34413195, 0.31292275, 0.33641866, 0.3091254...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
M998,"[0.5570401, 0.5829394, 0.60431653, 0.6357657, ...","[0.8822199, 0.8256937, 0.93422407, 0.94306266,...","[1993-03-31T00:00:00.000000000, 1993-04-30T00:..."
