2 changes: 1 addition & 1 deletion src/esn/deepesn.jl
@@ -84,7 +84,7 @@ function DeepESN(train_data::AbstractArray, in_size::Int, res_size::Int; depth::
         input_matrix, bias_vector)
     train_data = train_data[:, (washout + 1):end]

-    DeepESN(res_size, train_data, nla_type, input_matrix,
+    return DeepESN(res_size, train_data, nla_type, input_matrix,
         inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
         states)
 end
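A note on the pattern repeated across this PR: Julia implicitly returns the value of the last evaluated expression, so prefixing it with `return` changes nothing at runtime; it only makes the intent explicit at every exit point. A minimal standalone sketch of the equivalence (illustration only, not code from the diff):

```julia
# Both definitions behave identically: Julia returns the last evaluated expression.
implicit_last(x) = x + 1

function explicit_return(x)
    return x + 1  # explicit `return` documents intent; the value is the same
end

@assert implicit_last(2) == explicit_return(2) == 3
```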
3 changes: 1 addition & 2 deletions src/esn/esn.jl
@@ -76,15 +76,14 @@ function ESN(train_data::AbstractArray, in_size::Int, res_size::Int;
         input_matrix, bias_vector)
     train_data = train_data[:, (washout + 1):end]

-    ESN(res_size, train_data, nla_type, input_matrix,
+    return ESN(res_size, train_data, nla_type, input_matrix,
         inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
         states)
 end

 function (esn::AbstractEchoStateNetwork)(prediction::AbstractPrediction,
         output_layer::AbstractOutputLayer; last_state=esn.states[:, [end]],
         kwargs...)
-    pred_len = prediction.prediction_len
     return obtain_esn_prediction(esn, prediction, last_state, output_layer;
         kwargs...)
 end
31 changes: 8 additions & 23 deletions src/esn/esn_inits.jl
@@ -237,30 +237,15 @@ julia> res_input = minimal_init(8, 3; p=0.8)# higher p -> more positive signs
 ```
 """
 function minimal_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
-        sampling_type::Symbol=:bernoulli, weight::Number=T(0.1), irrational::Real=pi,
-        start::Int=1, p::Number=T(0.5)) where {T <: Number}
+        sampling_type::Symbol=:bernoulli, kwargs...) where {T <: Number}
     res_size, in_size = dims
-    if sampling_type == :bernoulli
-        layer_matrix = _create_bernoulli(p, res_size, in_size, weight, rng, T)
-    elseif sampling_type == :irrational
-        layer_matrix = _create_irrational(irrational,
-            start,
-            res_size,
-            in_size,
-            weight,
-            rng,
-            T)
-    else
-        error("""\n
-            Sampling type not allowed.
-            Please use one of :bernoulli or :irrational\n
-            """)
-    end
+    f_sample = getfield(@__MODULE__, sampling_type)
+    layer_matrix = f_sample(rng, T, res_size, in_size; kwargs...)
     return layer_matrix
 end

-function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Number,
-        rng::AbstractRNG, ::Type{T}) where {T <: Number}
+function bernoulli(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
+        weight::Number=T(0.1), p::Number=T(0.5)) where {T <: Number}
     input_matrix = DeviceAgnostic.zeros(rng, T, res_size, in_size)
     for i in 1:res_size
         for j in 1:in_size
@@ -274,9 +259,9 @@ function _create_bernoulli(p::Number, res_size::Int, in_size::Int, weight::Numbe
     return input_matrix
 end

-function _create_irrational(irrational::Irrational, start::Int, res_size::Int,
-        in_size::Int, weight::Number, rng::AbstractRNG,
-        ::Type{T}) where {T <: Number}
+function irrational(rng::AbstractRNG, ::Type{T}, res_size::Int, in_size::Int;
+        irrational::Irrational=pi, start::Int=1,
+        weight::Number=T(0.1)) where {T <: Number}
     setprecision(BigFloat, Int(ceil(log2(10) * (res_size * in_size + start + 1))))
     ir_string = string(BigFloat(irrational)) |> collect
     deleteat!(ir_string, findall(x -> x == '.', ir_string))
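This is the substantive refactor of the PR: the if/elseif dispatch in `minimal_init` becomes a `getfield(@__MODULE__, sampling_type)` lookup, so `sampling_type` must name a sampler function defined in the module (`bernoulli` or `irrational`), and sampler-specific keywords now flow through `kwargs...`. One behavioral consequence: an unrecognized `sampling_type` now raises an `UndefVarError` rather than the old custom error message. A hedged usage sketch, assuming ReservoirComputing is loaded and the new samplers are reachable from the calling scope:

```julia
using Random

rng = Random.default_rng()

# Through the public entry point: :bernoulli is resolved with getfield and
# p is forwarded to the sampler via kwargs...
W_in = minimal_init(rng, Float32, 8, 3; sampling_type=:bernoulli, p=0.8)

# Calling the sampler directly should also work now, since bernoulli follows
# the standard (rng, T, dims...; kwargs...) initializer convention.
W_direct = bernoulli(rng, Float32, 8, 3; p=0.8, weight=0.1f0)
```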
10 changes: 5 additions & 5 deletions src/esn/esn_predict.jl
@@ -25,7 +25,7 @@ function obtain_esn_prediction(esn,
         states[:, i] = x
     end

-    save_states ? (output, states) : output
+    return save_states ? (output, states) : output
 end

 function obtain_esn_prediction(esn,
@@ -55,7 +55,7 @@ function obtain_esn_prediction(esn,
         states[:, i] = x
     end

-    save_states ? (output, states) : output
+    return save_states ? (output, states) : output
 end

 #prediction dispatch on esn
@@ -98,11 +98,11 @@ function allocate_outpad(hesn::HybridESN, states_type, out)
 end

 function allocate_singlepadding(::AbstractPaddedStates, out)
-    adapt(typeof(out), zeros(size(out, 1) + 1))
+    return adapt(typeof(out), zeros(size(out, 1) + 1))
 end
 function allocate_singlepadding(::StandardStates, out)
-    adapt(typeof(out), zeros(size(out, 1)))
+    return adapt(typeof(out), zeros(size(out, 1)))
 end
 function allocate_singlepadding(::ExtendedStates, out)
-    adapt(typeof(out), zeros(size(out, 1)))
+    return adapt(typeof(out), zeros(size(out, 1)))
 end
10 changes: 5 additions & 5 deletions src/esn/esn_reservoir_drivers.jl
@@ -101,19 +101,19 @@ echo state networks (`ESN`).
     Defaults to 1.0.
 """
 function RNN(; activation_function=fast_act(tanh), leaky_coefficient=1.0)
-    RNN(activation_function, leaky_coefficient)
+    return RNN(activation_function, leaky_coefficient)
 end

 function reservoir_driver_params(rnn::RNN, args...)
-    rnn
+    return rnn
 end

 function next_state!(out, rnn::RNN, x, y, W, W_in, b, tmp_array)
     mul!(tmp_array[1], W, x)
     mul!(tmp_array[2], W_in, y)
     @. tmp_array[1] = rnn.activation_function(tmp_array[1] + tmp_array[2] + b) *
                       rnn.leaky_coefficient
-    @. out = (1 - rnn.leaky_coefficient) * x + tmp_array[1]
+    return @. out = (1 - rnn.leaky_coefficient) * x + tmp_array[1]
 end

 function next_state!(out, rnn::RNN, x, y, W::Vector, W_in, b, tmp_array)
@@ -353,7 +353,7 @@ function obtain_gru_state!(out, variant::FullyGated, gru, x, y, W, W_in, b, tmp_
     mul!(tmp_array[7], W_in, y)
     mul!(tmp_array[8], W, tmp_array[6] .* x)
     @. tmp_array[9] = gru.activation_function[3](tmp_array[7] + tmp_array[8] + b)
-    @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9]
+    return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9]
 end

 #minimal
@@ -366,5 +366,5 @@ function obtain_gru_state!(out, variant::Minimal, gru, x, y, W, W_in, b, tmp_arr
     mul!(tmp_array[5], W, tmp_array[3] .* x)
     @. tmp_array[6] = gru.activation_function[2](tmp_array[4] + tmp_array[5] + b)

-    @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6]
+    return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[6]
 end
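One pattern in these driver changes is worth spelling out: `@. out = ...` lowers to the fused broadcast `out .= ...`, which writes into `out` in place and also evaluates to `out`, so `return @. out = ...` hands the freshly updated buffer back to the caller without allocating. A simplified standalone sketch of the idea (a generic leaky update, not the library's exact RNN driver):

```julia
# `@. out = ...` lowers to `out .= ...`, which mutates `out` and
# evaluates to `out` itself, so it can double as the return value.
function leaky_step!(out, x, h, alpha)
    return @. out = (1 - alpha) * x + alpha * h
end

out = zeros(3)
y = leaky_step!(out, ones(3), fill(2.0, 3), 0.5)
@assert y === out && all(out .== 1.5)  # same buffer, updated in place
```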
2 changes: 1 addition & 1 deletion src/esn/hybridesn.jl
@@ -119,7 +119,7 @@ function HybridESN(model::KnowledgeModel, train_data::AbstractArray,
         input_matrix, bias_vector)
     train_data = train_data[:, (washout + 1):end]

-    HybridESN(res_size, train_data, model, nla_type, input_matrix,
+    return HybridESN(res_size, train_data, model, nla_type, input_matrix,
         inner_res_driver, reservoir_matrix, bias_vector, states_type, washout,
         states)
 end
2 changes: 1 addition & 1 deletion src/predict.jl
@@ -63,7 +63,7 @@ on the learned relationships in the model.
 """
 function Predictive(prediction_data::AbstractArray)
     prediction_len = size(prediction_data, 2)
-    Predictive(prediction_data, prediction_len)
+    return Predictive(prediction_data, prediction_len)
 end

 function obtain_prediction(rc::AbstractReservoirComputer, prediction::Generative,
6 changes: 3 additions & 3 deletions src/states.jl
@@ -3,13 +3,13 @@ abstract type AbstractPaddedStates <: AbstractStates end
 abstract type NonLinearAlgorithm end

 function pad_state!(states_type::AbstractPaddedStates, x_pad, x)
-    x_pad = vcat(fill(states_type.padding, (1, size(x, 2))), x)
+    x_pad[1, :] .= states_type.padding
+    x_pad[2:end, :] .= x
     return x_pad
 end

 function pad_state!(states_type, x_pad, x)
-    x_pad = x
-    return x_pad
+    return x
 end

 #states types
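The `pad_state!` change is more than style: the old body built a brand-new array with `vcat` and rebound the local `x_pad`, so the caller's preallocated buffer was never written despite the `!` in the name, and every call allocated. The new body writes the padding row and the state into the existing buffer in place. A minimal standalone illustration (`padding` stands in for `states_type.padding`; not the library's code):

```julia
x = [1.0 2.0; 3.0 4.0]
x_pad = zeros(3, 2)     # caller-owned buffer: one padding row on top of x

padding = 1.0
x_pad[1, :] .= padding  # in-place: fill the padding row
x_pad[2:end, :] .= x    # in-place: copy the state below it

@assert x_pad == [1.0 1.0; 1.0 2.0; 3.0 4.0]
```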