In [1]:
import pandas as pd

df = pd.read_parquet("data/intermediate/m3-monthly_scaled.parquet")

In [2]:
from ts.models.nbeats import NBeatsG

horizon = 12  # <-- FORECAST HORIZON
input_size = horizon * 5

model = NBeatsG(
    input_size=input_size,
    horizon=horizon,
)

In [3]:
import torch

from ts.preprocess.dataloader import TSPreprocessor, UnivariateTSDataModule

batch_size = 512
num_workers = 24
step_size = 6

# Initialize preprocessor
preprocessor = TSPreprocessor(
    df=df,
    input_size=input_size,
    horizon=horizon,
    target_col="y",
    train_split=0.7,
    val_split=0.15,
    normalize=True,
    scaler_type="minmax",
    split_type="vertical",
    step_size=step_size,
    cache_dir=".",
    use_cache=False,
    persist_scaler=True,
    experiment_name="my_experiment",
)
# Initialize DataModule
ds = UnivariateTSDataModule(
    preprocessor=preprocessor,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
    gpu_preload=False,
)

Processing series: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 1428/1428 [00:00<00:00, 2384.83it/s]


[(array([0.14181685, 0.15967393, 0.1302011 , 0.15585995, 0.12066579,
         0.12378645, 0.16366172, 0.18862677, 0.20891118, 0.22902226,
         0.22815537, 0.17024946, 0.19798875, 0.23821092, 0.26109576,
         0.31761432, 0.35957003, 0.3514216 , 0.3127601 , 0.39597797,
         0.32298875, 0.3193481 , 0.26248264, 0.28068662, 0.21688628,
         0.22226071, 0.19694853, 0.1967752 , 0.19729519, 0.18654633,
         0.19209409, 0.19278765, 0.20353675, 0.25814843, 0.29559636,
         0.3009708 , 0.14476418, 0.17406392, 0.08165741, 0.064147  ,
         0.0473299 , 0.07090855, 0.03484726, 0.        , 0.00260043,
         0.02080441, 0.02756596, 0.07194877, 0.08564496, 0.18013191,
         0.17822456, 0.1510055 , 0.18325233, 0.21879315, 0.16660881,
         0.23300982, 0.22624826, 0.22070026, 0.25901532, 0.27219152],
        dtype=float32),
  array([0.2772193 , 0.27791262, 0.3184812 , 0.34153962, 0.38557577,
         0.39979196, 0.40690017, 0.41522193, 0.40083218, 0.41712904,
         

In [4]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="val_smape",
    mode="min",
    dirpath="checkpoints/",
    filename="nbeat_m3_run1",
)

trainer = pl.Trainer(
    logger=TensorBoardLogger("logs", name="nbeat_m3_run1"),
    max_epochs=200,  # Short run for testing
    accelerator="auto",
    precision="16-mixed",
    gradient_clip_val=1.0,
    callbacks=[EarlyStopping("val_smape", patience=10, verbose=False), checkpoint_callback],
    # profiler=profiler,
    accumulate_grad_batches=4,
    # log_every_n_steps=1
    # strategy="ddp_notebook"
    # strategy="ddp_notebook",
    # devices=1,
    # num_nodes=2,
)

trainer.fit(model, ds)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/pranav-pc/projects/ts/.venv/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/pranav-pc/projects/ts/nbs/pipeline/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                                 | Params | Mode 
-------------------------------------------------------------------------
0 | stacks  | ModuleList                           | 61.9 M | train
1 | loss_fn | MSELoss                              | 0      | train
2 | smape   | SymmetricMeanAbsolutePercentageError | 0      | train
3 | mase    | MASE                                 | 0      | train
4 | owa     | OWA                                  | 0      | train
-------------------------------------------------------------------------
61.9 M    Trainable params

Sanity Checking: |                                                                                            …

/home/pranav-pc/projects/ts/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (15) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

In [5]:
print(trainer.test(model, ds))
print(trainer.validate(model, ds))

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                                                                    …

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[{'test_loss': 0.020553255453705788, 'test_smape': 0.29276177287101746, 'test_mase': 0.00010583102994132787, 'test_owa': 0.8816570043563843}]


Validation: |                                                                                                 …

[{'val_loss': 0.01787063106894493, 'val_smape': 0.30310124158859253, 'val_mase': 0.00028084401856176555, 'val_owa': 0.8639543652534485}]
