diff --git a/src/esn/deepesn.jl b/src/esn/deepesn.jl index bdd44c08..ed599226 100644 --- a/src/esn/deepesn.jl +++ b/src/esn/deepesn.jl @@ -84,7 +84,7 @@ function DeepESN(train_data::AbstractArray, in_size::Int, res_size::Int; depth:: input_matrix, bias_vector) train_data = train_data[:, (washout + 1):end] - DeepESN(res_size, train_data, nla_type, input_matrix, + return DeepESN(res_size, train_data, nla_type, input_matrix, inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, states) end diff --git a/src/esn/esn.jl b/src/esn/esn.jl index d541c258..8d1854b0 100644 --- a/src/esn/esn.jl +++ b/src/esn/esn.jl @@ -76,7 +76,7 @@ function ESN(train_data::AbstractArray, in_size::Int, res_size::Int; input_matrix, bias_vector) train_data = train_data[:, (washout + 1):end] - ESN(res_size, train_data, nla_type, input_matrix, + return ESN(res_size, train_data, nla_type, input_matrix, inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, states) end @@ -84,7 +84,6 @@ end function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction, output_layer::AbstractOutputLayer; last_state=esn.states[:, [end]], kwargs...) - pred_len = prediction.prediction_len return obtain_esn_prediction(esn, prediction, last_state, output_layer; kwargs...) end diff --git a/src/esn/esn_inits.jl b/src/esn/esn_inits.jl index 5d9209ca..43f8749b 100644 --- a/src/esn/esn_inits.jl +++ b/src/esn/esn_inits.jl @@ -237,30 +237,15 @@ julia> res_input = minimal_init(8, 3; p=0.8)# higher p -> more positive signs ``` """ function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; - sampling_type::Symbol=:bernoulli, weight::Number=T(0.1), irrational::Real=pi, - start::Int=1, p::Number=T(0.5)) where {T <: Number} + sampling_type::Symbol=:bernoulli, kwargs...) where {T <: Number} res_size, in_size = dims - if sampling_type == :bernoulli - layer_matrix = _create_bernoulli(p, res_size, in_size, weight, rng, T) - elseif sampling_type == :irrational - layer_matrix = _create_irrational(irrational, - start, - res_size, - in_size, - weight, - rng, - T) - else - error("""\n - Sampling type not allowed. - Please use one of :bernoulli or :irrational\n - """) - end + f_sample = getfield(@__MODULE__, sampling_type) + layer_matrix = f_sample(rng, T, res_size, in_size; kwargs...) return layer_matrix end -function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Number, - rng::AbstractRNG, ::Type{T}) where {T <: Number} +function bernoulli(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int; + weight::Number=T(0.1), p::Number=T(0.5)) where {T <: Number} input_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size) for i in 1:res_size for j in 1:in_size @@ -274,9 +259,9 @@ function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Numbe return input_matrix end -function _create_irrational(irrational::Irrational, start::Int, res_size::Int, - in_size::Int, weight::Number, rng::AbstractRNG, - ::Type{T}) where {T <: Number} +function irrational(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int; + irrational::Irrational=pi, start::Int=1, + weight::Number=T(0.1)) where {T <: Number} setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1)))) ir_string = string(BigFloat(irrational)) |> collect deleteat!(ir_string, findall(x -> x == '.', ir_string)) diff --git a/src/esn/esn_predict.jl b/src/esn/esn_predict.jl index 46e30e95..dc61c3c4 100644 --- a/src/esn/esn_predict.jl +++ b/src/esn/esn_predict.jl @@ -25,7 +25,7 @@ function obtain_esn_prediction(esn, states[:, i] = x end - save_states ? (output, states) : output + return save_states ? (output, states) : output end function obtain_esn_prediction(esn, @@ -55,7 +55,7 @@ function obtain_esn_prediction(esn, states[:, i] = x end - save_states ? (output, states) : output + return save_states ? (output, states) : output end #prediction dispatch on esn @@ -98,11 +98,11 @@ function allocate_outpad(hesn::HybridESN, states_type, out) end function allocate_singlepadding(::AbstractPaddedStates, out) - adapt(typeof(out), zeros(size(out, 1) + 1)) + return adapt(typeof(out), zeros(size(out, 1) + 1)) end function allocate_singlepadding(::StandardStates, out) - adapt(typeof(out), zeros(size(out, 1))) + return adapt(typeof(out), zeros(size(out, 1))) end function allocate_singlepadding(::ExtendedStates, out) - adapt(typeof(out), zeros(size(out, 1))) + return adapt(typeof(out), zeros(size(out, 1))) end diff --git a/src/esn/esn_reservoir_drivers.jl b/src/esn/esn_reservoir_drivers.jl index a5e207af..e74f8f8c 100644 --- a/src/esn/esn_reservoir_drivers.jl +++ b/src/esn/esn_reservoir_drivers.jl @@ -101,11 +101,11 @@ echo state networks (`ESN`). Defaults to 1.0. """ function RNN(; activation_function=fast_act(tanh), leaky_coefficient=1.0) - RNN(activation_function, leaky_coefficient) + return RNN(activation_function, leaky_coefficient) end function reservoir_driver_params(rnn::RNN, args...) - rnn + return rnn end function next_state!(out, rnn::RNN, x, y, W, W_in, b, tmp_array) @@ -113,7 +113,7 @@ function next_state!(out, rnn::RNN, x, y, W, W_in, b, tmp_array) mul!(tmp_array[2], W_in, y) @. tmp_array[1] = rnn.activation_function(tmp_array[1] + tmp_array[2] + b) * rnn.leaky_coefficient - @. out = (1 - rnn.leaky_coefficient) * x + tmp_array[1] + return @. out = (1 - rnn.leaky_coefficient) * x + tmp_array[1] end function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array) @@ -353,7 +353,7 @@ function obtain_gru_state!(out, variant::FullyGated, gru, x, y, W, W_in, b, tmp_ mul!(tmp_array[7], W_in, y) mul!(tmp_array[8], W, tmp_array[6] .* x) @. tmp_array[9] = gru.activation_function[3](tmp_array[7] + tmp_array[8] + b) - @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9] + return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9] end #minimal @@ -366,5 +366,5 @@ function obtain_gru_state!(out, variant::Minimal, gru, x, y, W, W_in, b, tmp_arr mul!(tmp_array[5], W, tmp_array[3] .* x) @. tmp_array[6] = gru.activation_function[2](tmp_array[4] + tmp_array[5] + b) - @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6] + return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6] end diff --git a/src/esn/hybridesn.jl b/src/esn/hybridesn.jl index 366799d1..0c0a945e 100644 --- a/src/esn/hybridesn.jl +++ b/src/esn/hybridesn.jl @@ -119,7 +119,7 @@ function HybridESN(model::KnowledgeModel, train_data::AbstractArray, input_matrix, bias_vector) train_data = train_data[:, (washout + 1):end] - HybridESN(res_size, train_data, model, nla_type, input_matrix, + return HybridESN(res_size, train_data, model, nla_type, input_matrix, inner_res_driver, reservoir_matrix, bias_vector, states_type, washout, states) end diff --git a/src/predict.jl b/src/predict.jl index 18a30bfb..a60fa992 100644 --- a/src/predict.jl +++ b/src/predict.jl @@ -63,7 +63,7 @@ on the learned relationships in the model. """ function Predictive(prediction_data::AbstractArray) prediction_len = size(prediction_data, 2) - Predictive(prediction_data, prediction_len) + return Predictive(prediction_data, prediction_len) end function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative, diff --git a/src/states.jl b/src/states.jl index 9f78d993..a38726b2 100644 --- a/src/states.jl +++ b/src/states.jl @@ -3,13 +3,13 @@ abstract type AbstractPaddedStates <: AbstractStates end abstract type NonLinearAlgorithm end function pad_state!(states_type::AbstractPaddedStates, x_pad, x) - x_pad = vcat(fill(states_type.padding, (1, size(x, 2))), x) + x_pad[1, :] .= states_type.padding + x_pad[2:end, :] .= x return x_pad end function pad_state!(states_type, x_pad, x) - x_pad = x - return x_pad + return x end #states types