In [24]:
import pandas as pd

df = pd.read_parquet("data/intermediate/m3-monthly_scaled.parquet")

In [25]:
from ts.models.nbeats import NBeatsG

horizon = 12  # <-- FORECAST HORIZON
input_size = horizon * 5

model = NBeatsG(
    input_size=input_size,
    horizon=horizon,
)

In [26]:
import torch

from ts.preprocess.dataloader import UnivariateTSDataModule, UnivariateTSDataset

# torch.serialization.add_safe_globals([UnivariateTSDataset])
batch_size = 512
num_workers = 24
step_size = 6
ds = UnivariateTSDataModule(
    # Core data and modeling parameters
    df=df,
    input_size=input_size,
    horizon=horizon,
    step_size=step_size,
    target_col="y_scaled",
    # DataLoader settings
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
    gpu_preload=False,
    # Data splitting
    train_split=0.7,
    val_split=0.15,
    split_type="vertical",
    # Normalization
    normalize=False,
    scaler_type="minmax",
    persist_scaler=True,
    # Caching
    use_cache=False,
    cache_dir="cache/",
    # Experiment
    experiment_name="nbeat_m3_run1",
)

In [30]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="val_smape",
    mode="min",
    dirpath="checkpoints/",
    filename="nbeat_m3_run1",
)

trainer = pl.Trainer(
    logger=TensorBoardLogger("logs", name="nbeat_m3_run1"),
    max_epochs=200,  # Short run for testing
    accelerator="auto",
    precision="16-mixed",
    gradient_clip_val=1.0,
    callbacks=[EarlyStopping("val_smape", patience=10, verbose=False), checkpoint_callback],
    # profiler=profiler,
    accumulate_grad_batches=4,
    # log_every_n_steps=1
    # strategy="ddp_notebook"
    # strategy="ddp_notebook",
    # devices=1,
    # num_nodes=2,
)

trainer.fit(model, ds)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)


Checkpoint directory /home/pranav-pc/projects/ts/nbs/pipeline/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                                 | Params | Mode 
-------------------------------------------------------------------------
0 | stacks  | ModuleList                           | 61.9 M | train
1 | loss_fn | MSELoss                              | 0      | train
2 | smape   | SymmetricMeanAbsolutePercentageError | 0      | train
3 | mase    | MASE           

Sanity Checking: |                                                                                            …


The number of training batches (16) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

In [31]:
print(trainer.test(model, ds))
print(trainer.validate(model, ds))

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                                                                    …

[{'test_loss': 0.010228564962744713, 'test_smape': 0.2536330819129944, 'test_mase': 0.00013090827269479632, 'test_owa': 0.6189939975738525}]


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |                                                                                                 …

[{'val_loss': 0.010317462496459484, 'val_smape': 0.26534244418144226, 'val_mase': 0.00014688160445075482, 'val_owa': 0.5794956684112549}]
