Skip to content

Commit

Permalink
added utilities file
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanBechtel committed Jan 3, 2022
1 parent cff3403 commit 1a59fc5
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 12 deletions.
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
long_description = fh.read()

setuptools.setup(
name="keras-beats-jonathan-bechtel",
version="0.0.6",
author="Jonathan Bechtel",
author_email="jonathan@jonathanbech.tel",
description="Lightweight installation of NBeats NN architecture for keras",
name= "keras-beats-jonathan-bechtel",
version= "0.0.9",
author= "Jonathan Bechtel",
author_email= "jonathan@jonathanbech.tel",
description= "Lightweight installation of NBeats NN architecture for keras",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/JonathanBechtel/KerasBeats",
Expand Down
3 changes: 2 additions & 1 deletion src/kerasbeats/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from models import GenericBlock, TrendBlock, SeasonalBlock, NBeats
from .nbeats import GenericBlock, TrendBlock, SeasonalBlock, NBeats
from .utilities import prep_time_series, prep_multiple_time_series
129 changes: 123 additions & 6 deletions src/kerasbeats/nbeats.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
# -*- coding: utf-8 -*-
"""
NBeats model that is articulated at the following paper:
NBeats model that is formalized in the following paper:
https://arxiv.org/abs/1905.10437
"""

import numpy as np
import torch as t
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras import Input
from tensorflow.keras import Model

### DIFFERENT BLOCK LAYERS: GENERIC, SEASONAL, TREND
Expand All @@ -30,7 +26,8 @@ def __init__(self,
num_neurons: int -> How many layers to put into each Dense layer in
the generic block
----
block_layers: int -> How many Dense layers to add to the block"""
block_layers: int -> How many Dense layers to add to the block
"""

# collection of layers in the block
self.layers_ = [keras.layers.Dense(num_neurons, activation = 'relu')
Expand Down Expand Up @@ -178,6 +175,7 @@ def call(self, inputs):
forecast = forecast_harmonics_sin + forecast_harmonics_cos
return backcast, forecast

### CREATES NESTED LAYERS INTO A SINGLE NBEATS LAYER
class NBeats(keras.layers.Layer):
def __init__(self,
model_type = 'generic',
Expand Down Expand Up @@ -278,3 +276,122 @@ def call(self, inputs):
residuals = keras.layers.Subtract()([residuals, backcast])
forecast = keras.layers.Add()([forecast, block_forecast])
return forecast

### BUILDS, COMPILES AND FITS A KERAS MODEL AROUND THE NBEATS LAYER
class NBeatsModel():
    """Wrapper that builds, compiles, fits and queries a keras model made of
    a single NBeats layer."""

    # Names of the constructor arguments, in declaration order. Only these
    # attributes are forwarded to the NBeats layer, so objects created later
    # (``model_layer``, ``model``) never leak into its constructor.
    _INIT_PARAMS = (
        'model_type',
        'lookback',
        'forecast_size',
        'num_generic_neurons',
        'num_generic_stacks',
        'num_generic_layers',
        'num_trend_neurons',
        'num_trend_stacks',
        'num_trend_layers',
        'num_seasonal_neurons',
        'num_seasonal_stacks',
        'num_seasonal_layers',
        'num_harmonics',
        'polynomial_term',
        'loss',
        'learning_rate',
        'batch_size',
    )

    def __init__(self,
                 model_type:str = 'generic',
                 lookback:int = 7,
                 forecast_size:int = 1,
                 num_generic_neurons:int = 512,
                 num_generic_stacks:int = 30,
                 num_generic_layers:int = 4,
                 num_trend_neurons:int = 256,
                 num_trend_stacks:int = 3,
                 num_trend_layers:int = 4,
                 num_seasonal_neurons:int = 2048,
                 num_seasonal_stacks:int = 3,
                 num_seasonal_layers:int = 4,
                 num_harmonics:int = 1,
                 polynomial_term:int = 3,
                 loss:str = 'mae',
                 learning_rate:float = 0.001,
                 batch_size: int = 1024):
        """Model used to create and initialize N-Beats model described in the following paper:
        https://arxiv.org/abs/1905.10437

        Arguments (default listed in the signature)
        -----------------------------------
        model_type: str -> what model architecture to use. Must be one of
            ['generic', 'interpretable']
        ----
        lookback: int -> what multiplier of the forecast size you want to use
            for your training window. This number is multiplied by
            forecast_size to get the training window size. For example, if
            your forecast size is 3 and your lookback is 4, your training
            window will be 4 * 3 = 12
        ----
        forecast_size: int -> how many steps into the future you want your
            model to predict
        ----
        num_generic_neurons: int -> the number of neurons (columns) you want
            in each Dense layer for the generic block
        ----
        num_generic_stacks: int -> how many generic blocks to connect together
        ----
        num_generic_layers: int -> how many Dense layers each generic block
            has. If you set this to 4 and num_generic_neurons to 128, each
            block has 4 Dense layers with 128 neurons in each one
        ----
        num_trend_neurons: int -> number of neurons to place within each
            Dense layer in each trend block
        ----
        num_trend_stacks: int -> number of trend blocks to stack on top of
            one another
        ----
        num_trend_layers: int -> number of Dense layers inside a trend block
        ----
        num_seasonal_neurons: int -> size of Dense layer in seasonal block
        ----
        num_seasonal_stacks: int -> number of seasonal blocks to stack on top
            of one another
        ----
        num_seasonal_layers: int -> number of Dense layers inside a seasonal
            block
        ----
        num_harmonics: int -> seasonal term to use for seasonal stack
        ----
        polynomial_term: int -> size of polynomial expansion for trend block
        ----
        loss: str -> what loss function to use inside keras. Accepts any
            regression loss function built into keras. You can find
            more info here: https://keras.io/api/losses/regression_losses/
        ----
        learning_rate: float -> learning rate to use when training the model
        ----
        batch_size: int -> batch size to use when training the model
        """
        self.model_type = model_type
        self.lookback = lookback
        self.forecast_size = forecast_size
        self.num_generic_neurons = num_generic_neurons
        self.num_generic_stacks = num_generic_stacks
        self.num_generic_layers = num_generic_layers
        self.num_trend_neurons = num_trend_neurons
        self.num_trend_stacks = num_trend_stacks
        self.num_trend_layers = num_trend_layers
        self.num_seasonal_neurons = num_seasonal_neurons
        self.num_seasonal_stacks = num_seasonal_stacks
        self.num_seasonal_layers = num_seasonal_layers
        self.num_harmonics = num_harmonics
        self.polynomial_term = polynomial_term
        self.loss = loss
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def _layer_kwargs(self):
        """Returns the constructor hyperparameters as a keyword dictionary."""
        return {name: getattr(self, name) for name in self._INIT_PARAMS}

    def build_layer(self):
        """Initializes the nested NBeats layer from the initial parameters.

        Stores the layer on ``self.model_layer`` and returns ``self``.
        """
        # Forwarding ``self.__dict__`` wholesale would also pass the
        # previously built ``model_layer`` / ``model`` objects on any call
        # after the first fit, so only the hyperparameters are sent.
        # NOTE(review): loss, learning_rate and batch_size are still forwarded
        # because the original first call did so and the full NBeats signature
        # is not visible here -- confirm whether the layer needs them.
        self.model_layer = NBeats(**self._layer_kwargs())
        return self

    def build_model(self):
        """Creates the keras model to use for fitting.

        Reads ``self.model_layer``, so ``build_layer`` must have been called
        first. The input has shape (forecast_size * lookback,). Stores the
        result on ``self.model`` and returns ``self``.
        """
        inputs = keras.layers.Input(shape = (self.forecast_size * self.lookback, ),
                                    dtype = 'float')
        forecasts = self.model_layer(inputs)
        self.model = Model(inputs, forecasts)
        return self

    def fit(self, X, y, **kwargs):
        """Builds, compiles and fits the model.

        The layer and the model are rebuilt on every call. The model is
        compiled with the Adam optimizer at ``self.learning_rate``, the
        configured loss, and 'mae' / 'mape' as metrics. ``X``, ``y`` and any
        extra keyword arguments are forwarded to the keras ``fit`` call along
        with ``self.batch_size``. Returns ``self``.
        """
        self.build_layer()
        self.build_model()
        self.model.compile(optimizer = keras.optimizers.Adam(self.learning_rate),
                           loss = [self.loss],
                           metrics = ['mae', 'mape'])
        self.model.fit(X, y, batch_size = self.batch_size, **kwargs)
        return self

    def predict(self, X, **kwargs):
        """Passes ``X`` and any keyword arguments to the fitted keras model's
        ``predict`` and returns its result. Requires ``fit`` (or
        ``build_model``) to have been called so that ``self.model`` exists.
        """
        return self.model.predict(X, **kwargs)

    def evaluate(self, y_true, y_pred, **kwargs):
        """Passes both arguments to the fitted keras model's ``evaluate``.

        NOTE: the arguments are forwarded positionally and unchanged, and
        keras ``Model.evaluate`` treats its first positional argument as the
        input data and its second as the target data. Despite the parameter
        names, pass the input windows as ``y_true`` and the corresponding
        targets as ``y_pred``.
        """
        return self.model.evaluate(y_true, y_pred, **kwargs)
113 changes: 113 additions & 0 deletions src/kerasbeats/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Helper functions for preparing time series data for use with NBeatsModel
"""

from pandas import DataFrame, Series
import numpy as np

class InvalidArgumentError(Exception):
    """Raised when input supplied by the user fails validation."""

def prep_time_series(data,
                     lookback:int = 7,
                     horizon:int = 1) -> (np.ndarray, np.ndarray):
    """
    Creates sliding training windows and their corresponding labels for a
    single univariate time series.

    E.g. if horizon = 2 and lookback = 3
    Input: [1, 2, 3, 4, 5, 6, 7, 8]
    Output: ([[1, 2, 3, 4, 5, 6]], [[7, 8]])

    inputs:
      data: univariate time series you want to create windows for. Can be
      a pandas DataFrame or Series, a numpy array, a list, or anything else
      numpy can convert to an array. Two-dimensional input is accepted when
      it holds a single value per row and is flattened before windowing
      lookback: multiple of forecast horizon that you want to use for
      training window
      horizon: how far out into the future you want to predict

    returns numpy array of shape
          (len(data) - (lookback * horizon + horizon) + 1, lookback * horizon)
          (training windows)
          and numpy array of shape
          (len(data) - (lookback * horizon + horizon) + 1, horizon)
          (labels)
    If the series is shorter than one full window, both arrays have zero rows.

    raises InvalidArgumentError if lookback or horizon is smaller than 1, or
    if data holds more than one value per row
    """
    # a non-positive size would silently produce empty or misaligned slices
    if lookback < 1 or horizon < 1:
        raise InvalidArgumentError(
            "lookback and horizon must both be at least 1")

    ### convert data into numpy array, if necessary
    if isinstance(data, (DataFrame, Series)):
        data = data.values
    else:
        data = np.asarray(data)

    if data.ndim > 1:
        if any(dim > 1 for dim in data.shape[1:]):
            raise InvalidArgumentError(
                "Input should be a univariate time series with only a "
                "single column")
        # a single-column 2-D input (e.g. a one column DataFrame) must be
        # flattened, otherwise the fancy indexing below yields 3-D output
        data = data.reshape(-1)

    # size of training window
    backcast_size = lookback * horizon

    # total length of data for training window + horizon
    window_size = backcast_size + horizon

    # number of complete windows that fit in the series (never negative)
    num_windows = max(len(data) - window_size + 1, 0)

    # row i of the index matrix is [i, i + 1, ..., i + window_size - 1]
    offsets = np.arange(window_size)[np.newaxis, :]
    starts = np.arange(num_windows)[:, np.newaxis]
    windowed_array = data[starts + offsets]

    # first backcast_size columns are the inputs, the remaining are labels
    return windowed_array[:, :backcast_size], windowed_array[:, backcast_size:]

def prep_multiple_time_series(data,
                              label_col: str,
                              data_col: str,
                              lookback: int = 7,
                              horizon: int = 1):
    """
    Builds training windows and labels for several time series that are
    stacked on top of each other in one DataFrame.

    Example (lookback = 2, horizon = 1):
      inputs: [['ar', 1],
               ['ar', 2],
               ['ar', 3],
               ['br', 5],
               ['br', 6],
               ['br', 7]]
      outputs: [[1, 2],      [[3],
                [5, 6]],      [7]]
    The values belonging to 'ar' and 'br' are windowed as separate series, so
    no window spans two different series.

    Arguments:
      data: pandas DataFrame with at least two columns, one holding the label
      of the time series each row belongs to and one holding the values
      label_col: the name of the column that labels each time series
      data_col: the name of the column that contains the time series values
      lookback: what multiple of your horizon you want your training data to
      be, eg -- a horizon of 2 and lookback of 5 creates a training window of 10
      horizon: how far into the future you want to predict

    Returns the windows of every series stacked vertically, and the matching
    labels stacked vertically, in order of first appearance of each label.
    """
    # one (windows, labels) pair per distinct label, in order of appearance
    per_series = [
        prep_time_series(
            data.loc[data[label_col] == series_label, data_col].values,
            lookback,
            horizon)
        for series_label in data[label_col].unique()
    ]

    stacked_windows = np.vstack([pair[0] for pair in per_series])
    stacked_labels = np.vstack([pair[1] for pair in per_series])
    return stacked_windows, stacked_labels

0 comments on commit 1a59fc5

Please sign in to comment.