
Y9008/NBEATS


NBEATS

N-BEATS: Neural basis expansion analysis for interpretable time series forecasting

NBEATS is a PyTorch-based library for deep-learning time series forecasting that implements N-BEATS (https://arxiv.org/pdf/1905.10437v3.pdf) and builds on nbeats-pytorch.

Dependencies: Python >=3.6

Installation

$ pip install NBEATS

Import

from NBEATS import NeuralBeats

Mandatory Parameters:

  • data
  • forecast_length

A basic model that uses only the mandatory parameters can produce forecasts as shown below:

import pandas as pd
from NBEATS import NeuralBeats

data = pd.read_csv('test.csv')
data = data.values  # n x 1 numpy array

model = NeuralBeats(data=data, forecast_length=5)
model.fit()
forecast = model.predict()
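To make the expected input concrete: data.values on a single-column DataFrame is an n x 1 array, one row per time step. The stdlib sketch below mimics that shape without pandas. The helper read_single_column, the column name sales and the sample values are made up for illustration and are not part of NBEATS.

```python
import csv
import io


def read_single_column(csv_text):
    """Hypothetical helper: parse a one-column CSV with a header row into n x 1 rows."""
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header row, as pd.read_csv does by default
    return [[float(row[0])] for row in reader if row]


# Made-up sample series; each inner list is one time step
sample = "sales\n112\n118\n132\n129\n"
data = read_single_column(sample)
print(len(data), len(data[0]))  # 4 1
```

With numpy installed, numpy.array(data) has the same (4, 1) shape as the data argument in the examples above.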

Optional parameters to the model object

Parameter                Default value
backcast_length          3 * forecast_length
path                     ' ' (path to save intermediate training checkpoints)
checkpoint_name          'NBEATS-checkpoint.th'
mode                     'cpu'
batch_size               len(data) / 10
thetas_dims              [4, 8]
nb_blocks_per_stack      3
share_weights_in_stack   False
train_percent            0.8
save_model               False
hidden_layer_units       128
stack                    [1, 1] (block codes as per the paper: 1 = GENERIC_BLOCK, 2 = TREND_BLOCK, 3 = SEASONALITY_BLOCK)
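The defaults in this table can be read as a function of the dataset size and forecast_length. The sketch below is a hypothetical helper, resolve_defaults, that is not part of the NBEATS API. It assumes batch_size is len(data)/10 rounded down (with a floor of 1), which the table does not spell out, and it translates the stack codes into block names.

```python
# Stack codes from the parameter table
STACK_TYPES = {1: "GENERIC_BLOCK", 2: "TREND_BLOCK", 3: "SEASONALITY_BLOCK"}


def resolve_defaults(n_rows, forecast_length, **overrides):
    """Hypothetical helper: documented defaults for n_rows of data, plus overrides."""
    params = {
        "backcast_length": 3 * forecast_length,
        "path": " ",
        "checkpoint_name": "NBEATS-checkpoint.th",
        "mode": "cpu",
        "batch_size": max(1, n_rows // 10),  # assumption: len(data)/10 rounded down
        "thetas_dims": [4, 8],
        "nb_blocks_per_stack": 3,
        "share_weights_in_stack": False,
        "train_percent": 0.8,
        "save_model": False,
        "hidden_layer_units": 128,
        "stack": [1, 1],
    }
    unknown = set(overrides) - set(params)
    if unknown:
        raise ValueError(f"unknown parameter(s): {sorted(unknown)}")
    params.update(overrides)
    # Translate numeric stack codes into readable block names
    params["stack_names"] = [STACK_TYPES[code] for code in params["stack"]]
    return params


defaults = resolve_defaults(n_rows=100, forecast_length=5)
print(defaults["backcast_length"], defaults["batch_size"], defaults["stack_names"])
```

Passing stack=[1, 3] as an override yields the block names GENERIC_BLOCK and SEASONALITY_BLOCK, matching the codes in the table.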

Functions

fit()

Trains the model. The default arguments are epoch=25, optimiser=Adam, plot=True and verbose=True.

Example:

import torch

model.fit(epoch=25, optimiser=torch.optim.AdamW(model.parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False), plot=False, verbose=True)

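N-BEATS learns from pairs of a backcast window (backcast_length past values) and the forecast_length values that follow it. As a conceptual sketch only (the library's own sampling may differ, for example by drawing windows at random), the hypothetical helpers below cut a series into such pairs and split them chronologically using train_percent.

```python
def make_windows(series, backcast_length, forecast_length):
    """Pair every backcast window with the forecast_length values that follow it."""
    xs, ys = [], []
    last_start = len(series) - backcast_length - forecast_length
    for start in range(last_start + 1):
        split = start + backcast_length
        xs.append(series[start:split])                    # model input
        ys.append(series[split:split + forecast_length])  # target
    return xs, ys


def chronological_split(xs, ys, train_percent=0.8):
    """Keep time order: the first train_percent of windows train, the rest validate."""
    cut = int(len(xs) * train_percent)
    return (xs[:cut], ys[:cut]), (xs[cut:], ys[cut:])


series = list(range(20))
xs, ys = make_windows(series, backcast_length=6, forecast_length=2)
(train_x, train_y), (val_x, val_y) = chronological_split(xs, ys)
print(len(xs), len(train_x), len(val_x))  # 13 10 3
```

Splitting by time rather than shuffling keeps validation windows strictly later than training windows, so the validation score reflects forecasting unseen future values.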

predict(predict_data)

The predict_data argument is optional. If it is omitted and the training data runs up to month m, the model forecasts the periods directly after the training data: months m+1, m+2 and m+3 when forecast_length=3. To forecast from month m+3 onwards instead, pass a numpy array of shape (backcast_length, 1) containing the preceding backcast_length values. With the default backcast_length of 3 x forecast_length, that means the 9 (3 x 3) months m-6 to m+2, which gives predictions for months m+3, m+4 and m+5.
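Building predict_data amounts to taking the backcast_length observations that immediately precede the first period to forecast. The hypothetical helper below (not part of NBEATS) does this in plain Python. The month numbering assumes the training data ends at month m = 20 and that months 21 and 22 (m+1 and m+2) have since been observed, matching the example above; the values are placeholders.

```python
def prediction_window(series, backcast_length, end=None):
    """Return the backcast_length values before index `end` as an n x 1 column."""
    end = len(series) if end is None else end
    if end < backcast_length:
        raise ValueError("not enough history for a full backcast window")
    return [[value] for value in series[end - backcast_length:end]]


forecast_length = 3
backcast_length = 3 * forecast_length  # the documented default, 9

# Placeholder monthly history: the value for month k is simply k, for months 1..22
history = list(range(1, 23))
window = prediction_window(history, backcast_length)
print([row[0] for row in window])  # [14, 15, 16, 17, 18, 19, 20, 21, 22]
```

Months 14 to 22 are m-6 to m+2 for m = 20. numpy.array(window) has shape (9, 1) and could then be passed as predict_data to forecast months 23, 24 and 25 (m+3 to m+5).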

Important note: backcast_length can be passed to the model together with forecast_length. For forecast_length=3, for example, backcast_length can be 6, 9, 12, ... up to 21, since the paper suggests values between 2 x forecast_length and 7 x forecast_length. The default is 3 x forecast_length.
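The range in this note can be enumerated directly. The hypothetical helper below lists the multiples of forecast_length from 2x to 7x, the same steps the note uses for forecast_length=3.

```python
def backcast_candidates(forecast_length, low=2, high=7):
    """Backcast lengths from low x to high x forecast_length, in steps of forecast_length."""
    return [k * forecast_length for k in range(low, high + 1)]


print(backcast_candidates(3))  # [6, 9, 12, 15, 18, 21]
```

Each candidate could be tried as the backcast_length argument of NeuralBeats and compared on held-out data.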

Returns forecasted values.

save(file) & load(file,optimizer):

save() writes the trained model to file; load() restores a previously saved model from file.

Example: model.save('NBEATS.th') or model.load('NBEATS.th')

DEMO

The demo below uses a GENERIC_BLOCK stack (1) followed by a SEASONALITY_BLOCK stack (3), i.e. stack=[1,3]. See the paper for details. Experimenting with the three block types (GENERIC, SEASONALITY and TREND) may improve accuracy.

import pandas as pd
from NBEATS import NeuralBeats
from torch import optim

data = pd.read_csv('test.csv')
data = data.values  # n x 1 numpy array

model = NeuralBeats(data=data, forecast_length=5, stack=[1, 3], nb_blocks_per_stack=3, thetas_dims=[3, 7])

# or use a prebuilt model
# model.load(file='NBEATS.th')


# use a customised optimiser with parameters
model.fit(epoch=35, optimiser=optim.AdamW(model.parameters, lr=0.001, betas=(0.9, 0.999), eps=1e-07, weight_decay=0.01, amsgrad=False))
# or
# model.fit()

forecast = model.predict()
# or
# model.predict(predict_data=pred_data), where pred_data is a numpy array of shape (backcast_length, 1)
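To see what stack=[1,3], nb_blocks_per_stack=3 and thetas_dims=[3,7] describe, the sketch below lists the resulting blocks. It assumes, as in nbeats-pytorch, that thetas_dims holds one entry per stack and that every block in a stack uses that stack's entry. describe_stacks is a hypothetical helper, not part of NBEATS.

```python
STACK_TYPES = {1: "GENERIC_BLOCK", 2: "TREND_BLOCK", 3: "SEASONALITY_BLOCK"}


def describe_stacks(stack, nb_blocks_per_stack, thetas_dims):
    """List (stack index, block index, block type, theta dim) for every block."""
    if len(stack) != len(thetas_dims):
        raise ValueError("thetas_dims needs one entry per stack")
    layout = []
    for stack_id, (code, theta_dim) in enumerate(zip(stack, thetas_dims)):
        for block_id in range(nb_blocks_per_stack):
            layout.append((stack_id, block_id, STACK_TYPES[code], theta_dim))
    return layout


layout = describe_stacks([1, 3], nb_blocks_per_stack=3, thetas_dims=[3, 7])
for row in layout:
    print(row)
print(len(layout), "blocks in total")  # 6 blocks in total
```

Under this assumption, switching to a TREND_BLOCK stack (code 2) means editing stack and keeping thetas_dims the same length, e.g. stack=[2, 3] with a two-element thetas_dims.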
