
MartinuzziFrancesco/NeuralBasisExpansions.jl


NBeats


Implementation of the N-BEATS model (paper) in Julia, using Flux. The package is not yet in the General registry, so install it directly from GitHub:

using Pkg
Pkg.add(url="https://github.com/MartinuzziFrancesco/NeuralBasisExpansions.jl")

The package is still under heavy testing, so expect unexpected behavior.

A full sine-wave example, including the helper functions used below (`generate_sine_data`, `batch_data` and `evaluate_predictions`), is in the example folder as readme.jl.
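The data helpers themselves are not reproduced in this README. As a rough, hypothetical sketch of what they need to return for the training loop to work (sliding-window input/target pairs, then `(x, y)` matrices with one column per sample, the features × batch layout Flux expects), they could look like this. The noise level, time step and handling of a trailing partial batch are assumptions, not the package's code.

```julia
# Hypothetical sketch of the example's data helpers; the real versions in
# example/readme.jl may differ.

"""
    generate_sine_data(n, backcast_length, forecast_length)

Return `n` (input, target) pairs cut from a noisy sine wave: each input holds
`backcast_length` consecutive points, each target the `forecast_length`
points that immediately follow.
"""
function generate_sine_data(n::Int, backcast_length::Int, forecast_length::Int)
    window = backcast_length + forecast_length
    t = range(0f0, step = 0.1f0, length = n + window)   # assumed time step
    series = sin.(t) .+ 0.05f0 .* randn(Float32, length(t))  # assumed noise level
    return [(series[i:i+backcast_length-1], series[i+backcast_length:i+window-1])
            for i in 1:n]
end

"""
    batch_data(data, batch_size)

Group (input, target) pairs into `(x, y)` matrices with one column per sample.
A trailing partial batch is kept (an assumption).
"""
function batch_data(data, batch_size::Int)
    # hcat(v...) always yields a Matrix, even for a one-sample chunk
    return [(hcat(first.(chunk)...), hcat(last.(chunk)...))
            for chunk in Iterators.partition(data, batch_size)]
end
```

With these definitions each training batch is a `(10, 32)` input matrix paired with a `(5, 32)` target matrix for the parameters below, which is what `Flux.train!` unpacks into `loss_fn(x, y)`.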

# Flux for training, Statistics for `mean`; the package itself and the
# helper functions from example/readme.jl must also be loaded
using Flux, Statistics

# Model parameters
forecast_length = 5
backcast_length = 2 * forecast_length
batch_size = 32
hidden_units = 128
theta_dims = (4, 8)
blocks_per_stack = 3

# Generate the data and split it into training and test batches
data = generate_sine_data(1000, backcast_length, forecast_length)
train_data, test_data = data[1:800], data[801:end]
train_batches = batch_data(train_data, batch_size)
test_batches = batch_data(test_data, batch_size)

# Create the NBeatsNet model
model = NBeatsNet(
    stacks=[generic_basis, trend_basis],
    blocks_stacks=blocks_per_stack,
    forecast_length=forecast_length,
    backcast_length=backcast_length,
    thetas_dim=theta_dims,
    hidden_units=hidden_units
)

# Loss on the forecast: the model returns a (backcast, forecast) tuple
loss_fn(x, y) = Flux.mse(model(x)[2], y)
optimizer = Flux.Adam(0.001)

# Training loop
epochs = 50
for epoch in 1:epochs
    Flux.train!(loss_fn, Flux.params(model), train_batches, optimizer)
    train_loss = mean(loss_fn(x, y) for (x, y) in train_batches)
    test_loss = mean(loss_fn(x, y) for (x, y) in test_batches)
    println("Epoch $epoch: Train Loss = $train_loss, Test Loss = $test_loss")
end

# Forecast on the first test batch and evaluate it
x_test, y_true = test_batches[1]
y_pred = model(x_test)[2]

mse, mae, r_squared = evaluate_predictions(y_true, y_pred)

println("Mean Squared Error: $mse")
println("Mean Absolute Error: $mae")
println("R-squared: $r_squared")
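`evaluate_predictions` returns three metrics for the forecast. A hypothetical sketch of such a helper, assuming it averages over every entry of the forecast matrix and computes R² as 1 − SS_res/SS_tot with SS_tot taken around the mean of `y_true` (the real function in example/readme.jl may differ):

```julia
using Statistics: mean

# Hypothetical sketch: MSE, MAE and R² over all entries of the forecast.
function evaluate_predictions(y_true::AbstractArray, y_pred::AbstractArray)
    residuals = y_true .- y_pred
    mse = mean(abs2, residuals)   # mean squared error
    mae = mean(abs, residuals)    # mean absolute error
    # R² = 1 - SS_res / SS_tot, where SS_tot is measured around mean(y_true)
    ss_res = sum(abs2, residuals)
    ss_tot = sum(abs2, y_true .- mean(y_true))
    r_squared = 1 - ss_res / ss_tot
    return mse, mae, r_squared
end
```

A perfect forecast gives `(0, 0, 1)`; a forecast that always predicts the mean of `y_true` gives an R² of 0.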

Quick example with random data to test the model

forecast_length = 5
backcast_length = 10
blocks_stacks = 3
thetas_dim = (4, 8)
hidden_units = 256

nbeats_net = NBeatsNet(
    stacks = [trend_basis, seasonality_basis],
    blocks_stacks = blocks_stacks,
    forecast_length = forecast_length,
    backcast_length = backcast_length,
    thetas_dim = thetas_dim,
    share_weights = false,
    hidden_units = hidden_units
)

# Create a batch of random input data, one column per instance (features × batch)
batch_size = 3  # Number of instances in the batch
input_data = randn(Float32, backcast_length, batch_size)

backcast_output, forecast_output = nbeats_net(input_data)
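The `stacks` argument chooses the basis each stack uses to turn its predicted coefficients θ into a signal. The package's `trend_basis` and `seasonality_basis` are not shown here; as a hedged sketch of the formulation in the N-BEATS paper (not the package's code), a trend block multiplies θ by a polynomial basis in the normalised time t = (0, 1, …, H−1)/H, and a seasonality block by a Fourier basis. The names `time_grid`, `trend_basis_matrix`, `seasonality_basis_matrix` and `expand` are hypothetical and used only for illustration.

```julia
# Sketch of the N-BEATS basis expansions as described in the paper; this is
# NOT the package's trend_basis / seasonality_basis implementation.

# Normalised time grid t = (0, 1, ..., H-1) / H shared by both bases.
time_grid(H::Int) = collect(0:H-1) ./ H

# Trend: polynomial basis with columns t^0, t^1, ..., t^(p-1), so a block
# with p coefficients models a polynomial of degree p-1.
function trend_basis_matrix(H::Int, p::Int)
    t = time_grid(H)
    return hcat([t .^ i for i in 0:p-1]...)
end

# Seasonality: Fourier basis with a constant column, then cosine and sine
# columns for harmonics 1, ..., nharm.
function seasonality_basis_matrix(H::Int, nharm::Int)
    t = time_grid(H)
    cols = [ones(H)]
    append!(cols, [cos.(2π * k .* t) for k in 1:nharm])
    append!(cols, [sin.(2π * k .* t) for k in 1:nharm])
    return hcat(cols...)
end

# A block's output is its basis matrix applied to the coefficients θ.
expand(basis::AbstractMatrix, θ::AbstractVector) = basis * θ
```

In the paper the polynomial degree is kept small so that the trend output stays smooth, and the number of harmonics is tied to the horizon length (⌊H/2 − 1⌋), which constrains what each stack can represent.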