Merge pull request #116 from acfr/75-improve-package-loading-speed
Attempt to improve package loading speeds
nic-barbara committed Aug 3, 2023
2 parents f0b43dd + 6b2efdc commit 1006401
Showing 14 changed files with 49 additions and 139 deletions.
8 changes: 4 additions & 4 deletions src/Base/lbdn_params.jl
@@ -56,8 +56,8 @@ See also [`DenseLBDNParams`](@ref).
"""
function DirectLBDNParams{T}(
nu::Int, nh::Vector{Int}, ny::Int, γ::Number = T(1);
- initW::Function = Flux.glorot_normal,
- initb::Function = Flux.glorot_normal,
+ initW::Function = glorot_normal,
+ initb::Function = glorot_normal,
learn_γ::Bool = false,
rng::AbstractRNG = Random.GLOBAL_RNG
) where T
@@ -83,8 +83,8 @@ function DirectLBDNParams{T}(
)
end

- Flux.@functor DirectLBDNParams
- function Flux.trainable(m::DirectLBDNParams)
+ @functor DirectLBDNParams
+ function trainable(m::DirectLBDNParams)
if m.learn_γ
return (XY=m.XY, α=m.α, d=m.d, b=m.b, log_γ=m.log_γ)
else
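The pattern in these hunks (tagging the struct with `@functor` and overloading `trainable` to return a `NamedTuple` of only the learnable fields) is how Flux and Optimisers decide which arrays receive gradient updates. A minimal illustrative sketch, using a hypothetical struct rather than anything from this package:

```julia
import Flux: trainable
using Flux: @functor

struct ToyParams
    W::Matrix{Float64}
    b::Vector{Float64}
    learn_b::Bool      # flag deciding whether `b` is trained
end

@functor ToyParams

# Flux/Optimisers only update the fields returned here; the return value must be a NamedTuple.
trainable(m::ToyParams) = m.learn_b ? (W=m.W, b=m.b) : (W=m.W,)
```

With `learn_b = false`, gradient-based optimisers update `W` but leave `b` untouched, which is the same role the `learn_γ` flag plays above.
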
31 changes: 5 additions & 26 deletions src/Base/ren_params.jl
@@ -222,9 +222,9 @@ function DirectRENParams{T}(
)
end

- Flux.@functor DirectRENParams
+ @functor DirectRENParams

- function Flux.trainable(m::DirectRENParams)
+ function trainable(m::DirectRENParams)

# Field names of trainable params, exclude ρ if needed
if !m.output_map
@@ -245,39 +245,18 @@ function Flux.trainable(m::DirectRENParams)
indx = length.(ps) .!= 0
ps, fs = ps[indx], fs[indx]

- # Flux.trainable() must return a NamedTuple
+ # Optimisers.trainable() must return a NamedTuple
return NamedTuple{tuple(fs...)}(ps)
end

- function Flux.gpu(M::DirectRENParams{T}) where T
- # TODO: Test and complete this
- if T != Float32
- println("Moving type: ", T, " to gpu may not be supported. Try Float32!")
- end
- return DirectRENParams{T}(
- gpu(M.ρ), gpu(M.X), gpu(M.Y1), gpu(M.X3), gpu(M.Y3),
- gpu(M.Z3), gpu(M.B2), gpu(M.C2), gpu(M.D12), gpu(M.D21),
- gpu(M.D22), gpu(M.bx), gpu(M.bv), gpu(M.by),
- M.ϵ, M.polar_param, M.D22_free, M.D22_zero
- )
- end

- function Flux.cpu(M::DirectRENParams{T}) where T
- # TODO: Test and complete this
- return DirectRENParams{T}(
- cpu(M.ρ), cpu(M.X), cpu(M.Y1), cpu(M.X3), cpu(M.Y3),
- cpu(M.Z3), cpu(M.B2), cpu(M.C2), cpu(M.D12), cpu(M.D21),
- cpu(M.D22), cpu(M.bx), cpu(M.bv), cpu(M.by),
- M.ϵ, M.polar_param, M.D22_free, M.D22_zero
- )
- end

"""
==(ps1::DirectRENParams, ps2::DirectRENParams)
Define equality for two objects of type `DirectRENParams`.
Checks if all *relevant* parameters are equal. For example, if `D22` is fixed to `0` then the values of `X3, Y3, Z3` are not important and are ignored.
This is currently not used. Might be useful in the future...
"""
function ==(ps1::DirectRENParams, ps2::DirectRENParams)

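As a side note, the `NamedTuple{tuple(fs...)}(ps)` idiom retained above builds a `NamedTuple` whose keys are the surviving field names and whose values are the matching parameters. A tiny self-contained illustration (values are made up):

```julia
fs = (:X, :bx)                    # field names that survived the filtering
ps = (randn(2, 2), randn(2))      # corresponding parameter values
nt = NamedTuple{fs}(ps)           # (X = [...], bx = [...])
nt.bx                             # fields are then accessed like struct fields
```
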
4 changes: 2 additions & 2 deletions src/Base/utils.jl
@@ -6,6 +6,6 @@
Generate matrices or vectors from the Glorot normal distribution.
"""
glorot_normal(n::Int, m::Int; T=Float64, rng=Random.GLOBAL_RNG) =
- convert.(T, Flux.glorot_normal(rng, n, m))
+ convert.(T, glorot_normal(rng, n, m))
glorot_normal(n::Int; T=Float64, rng=Random.GLOBAL_RNG) =
- convert.(T, Flux.glorot_normal(rng, n))
+ convert.(T, glorot_normal(rng, n))
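These wrappers simply forward to Flux's `glorot_normal` with an explicit RNG and convert the element type. A hedged usage sketch (sizes and seed are arbitrary):

```julia
using Random

rng = MersenneTwister(42)
W = glorot_normal(4, 3; T=Float32, rng=rng)   # 4×3 Float32 matrix
b = glorot_normal(4; T=Float32, rng=rng)      # length-4 Float32 vector
```
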
28 changes: 4 additions & 24 deletions src/ParameterTypes/contracting_ren.jl
@@ -25,7 +25,7 @@ The parameters can be used to construct an explicit [`REN`](@ref) model that has
# Keyword arguments
- - `nl::Function=Flux.relu`: Sector-bounded static nonlinearity.
+ - `nl::Function=relu`: Sector-bounded static nonlinearity.
- `αbar::T=1`: Upper bound on the contraction rate with `ᾱ ∈ (0,1]`.
@@ -35,7 +35,7 @@ See also [`GeneralRENParams`](@ref), [`LipschitzRENParams`](@ref), [`PassiveRENP
"""
function ContractingRENParams{T}(
nu::Int, nx::Int, nv::Int, ny::Int;
- nl::Function = Flux.relu,
+ nl::Function = relu,
αbar::T = T(1),
init = :random,
polar_param::Bool = true,
@@ -60,19 +60,15 @@ end

@doc raw"""
ContractingRENParams(nv, A, B, C, D; ...)
Alternative constructor for `ContractingRENParams` that initialises the
REN from a **stable** discrete-time linear system with state-space model
```math
\begin{align*}
x_{t+1} &= Ax_t + Bu_t \\
y_t &= Cx_t + Du_t.
\end{align*}
```
- [TODO:] This method may be removed in a later edition of the package.
+ [TODO:] This method has not been used or tested in a while. If you find it useful, please reach out to us and we will add full support and testing! :)
[TODO:] Make compatible with αbar ≠ 1.0.
"""
function ContractingRENParams(
@@ -150,23 +146,7 @@ function ContractingRENParams(

end

- Flux.@functor ContractingRENParams (direct, )

- function Flux.gpu(m::ContractingRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.gpu(m.direct)
- return ContractingRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar
- )
- end

- function Flux.cpu(m::ContractingRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.cpu(m.direct)
- return ContractingRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar
- )
- end
+ @functor ContractingRENParams (direct, )

function direct_to_explicit(ps::ContractingRENParams{T}, return_h::Bool=false) where T

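For the alternative constructor documented above, a rough usage sketch follows; the system matrices are invented and must describe a stable discrete-time system (all eigenvalues of `A` strictly inside the unit circle), and the keyword defaults are assumed to apply:

```julia
# Hypothetical stable 2-state, 1-input, 1-output linear system
A = [0.8 0.1; 0.0 0.5]            # eigenvalues 0.8 and 0.5, so the system is stable
B = reshape([1.0, 0.0], 2, 1)
C = [1.0 1.0]
D = zeros(1, 1)

nv = 10                           # number of neurons in the REN
ps = ContractingRENParams(nv, A, B, C, D)
```
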
14 changes: 5 additions & 9 deletions src/ParameterTypes/dense_lbdn.jl
@@ -23,25 +23,25 @@ This is the equivalent of a multi-layer perceptron (eg: `Flux.Dense`) with a gua
# Keyword arguments:
- - `nl::Function=Flux.relu`: Sector-bounded static nonlinearity.
+ - `nl::Function=relu`: Sector-bounded static nonlinearity.
- `learn_γ::Bool=false:` Whether to make the Lipschitz bound γ a learnable parameter.
See [`DirectLBDNParams`](@ref) for documentation of keyword arguments `initW`, `initb`, `rng`.
"""
function DenseLBDNParams{T}(
nu::Int, nh::Vector{Int}, ny::Int, γ::Number = T(1);
- nl::Function = Flux.relu,
- initW::Function = Flux.glorot_normal,
- initb::Function = Flux.glorot_normal,
+ nl::Function = relu,
+ initW::Function = glorot_normal,
+ initb::Function = glorot_normal,
learn_γ::Bool = false,
rng::AbstractRNG = Random.GLOBAL_RNG
) where T
direct = DirectLBDNParams{T}(nu, nh, ny, γ; initW, initb, learn_γ, rng)
return DenseLBDNParams{T}(nl, nu, nh, ny, direct)
end

- Flux.@functor DenseLBDNParams (direct, )
+ @functor DenseLBDNParams (direct, )

function direct_to_explicit(ps::DenseLBDNParams{T}) where T

@@ -112,7 +112,3 @@ function get_AB(
return copy(buf_A), copy(buf_B)

end

- # TODO: Add GPU compatibility
- # Flux.cpu() ...
- # Flux.gpu() ...
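A rough usage sketch for the constructor above (sizes and the Lipschitz bound are arbitrary; wrapping the parameters in `DiffLBDN` is assumed to follow the package's usual pattern):

```julia
using Random

nu, nh, ny = 2, [16, 32], 1           # input size, hidden widths, output size
γ = 5.0f0                             # Lipschitz upper bound
ps = DenseLBDNParams{Float32}(nu, nh, ny, γ; learn_γ=true, rng=MersenneTwister(0))

model = DiffLBDN(ps)                  # differentiable wrapper (see diff_lbdn.jl below)
y = model(randn(Float32, nu, 10))     # evaluate on a batch of 10 inputs
```
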
22 changes: 3 additions & 19 deletions src/ParameterTypes/general_ren.jl
@@ -31,7 +31,7 @@ Behavioural constraints are encoded by the matrices `Q,S,R` in an incremental In
# Keyword arguments
- - `nl::Function=Flux.relu`: Sector-bounded static nonlinearity.
+ - `nl::Function=relu`: Sector-bounded static nonlinearity.
- `αbar::T=1`: Upper bound on the contraction rate with `ᾱ ∈ (0,1]`.
@@ -42,7 +42,7 @@ See also [`ContractingRENParams`](@ref), [`LipschitzRENParams`](@ref), [`Passive
function GeneralRENParams{T}(
nu::Int, nx::Int, nv::Int, ny::Int,
Q::Matrix{T}, S::Matrix{T}, R::Matrix{T};
- nl::Function = Flux.relu,
+ nl::Function = relu,
αbar::T = T(1),
init = :random,
polar_param::Bool = true,
@@ -82,23 +82,7 @@ function GeneralRENParams{T}(

end

- Flux.@functor GeneralRENParams (direct, )

- function Flux.gpu(m::GeneralRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.gpu(m.direct)
- return GeneralRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar, m.Q, m.S, m.R
- )
- end

- function Flux.cpu(m::GeneralRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.cpu(m.direct)
- return GeneralRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar, m.Q, m.S, m.R
- )
- end
+ @functor GeneralRENParams (direct, )

function direct_to_explicit(ps::GeneralRENParams{T}, return_h=false) where T

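As a hedged illustration of the `Q, S, R` interface: assuming the standard incremental IQC convention from the REN literature (check the package documentation for the exact convention and dimensions used here), choosing `Q = -I/γ`, `S = 0`, `R = γI` corresponds to a Lipschitz bound of `γ`:

```julia
using LinearAlgebra
using Random

nu, nx, nv, ny = 1, 4, 10, 1
γ = 2.0

# IQC matrices for a Lipschitz bound of γ (assumed convention; Q: ny×ny, S: nu×ny, R: nu×nu)
Q = -Matrix{Float64}(I, ny, ny) / γ
S = zeros(nu, ny)
R = γ * Matrix{Float64}(I, nu, nu)

ps = GeneralRENParams{Float64}(nu, nx, nv, ny, Q, S, R; rng=MersenneTwister(0))
```
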
24 changes: 4 additions & 20 deletions src/ParameterTypes/lipschitz_ren.jl
@@ -26,7 +26,7 @@ Construct direct parameterisation of a REN with a Lipschitz bound of γ.
# Keyword arguments
- - `nl::Function=Flux.relu`: Sector-bounded static nonlinearity.
+ - `nl::Function=relu`: Sector-bounded static nonlinearity.
- `αbar::T=1`: Upper bound on the contraction rate with `ᾱ ∈ (0,1]`.
@@ -38,7 +38,7 @@ See also [`GeneralRENParams`](@ref), [`ContractingRENParams`](@ref), [`PassiveRE
"""
function LipschitzRENParams{T}(
nu::Int, nx::Int, nv::Int, ny::Int, γ::Number;
- nl::Function = Flux.relu,
+ nl::Function = relu,
αbar::T = T(1),
learn_γ::Bool = false,
init = :random,
@@ -65,27 +65,11 @@ function LipschitzRENParams{T}(

end

- Flux.@functor LipschitzRENParams
- function Flux.trainable(m::LipschitzRENParams)
+ @functor LipschitzRENParams
+ function trainable(m::LipschitzRENParams)
m.learn_γ ? (direct = m.direct, γ = m.γ) : (direct = m.direct,)
end

- function Flux.gpu(m::LipschitzRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.gpu(m.direct)
- return LipschitzRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar, m.γ
- )
- end

- function Flux.cpu(m::LipschitzRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.cpu(m.direct)
- return LipschitzRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar, m.γ
- )
- end

function direct_to_explicit(ps::LipschitzRENParams{T}, return_h=false) where T

# System sizes
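The Lipschitz bound γ guarantees `‖y₁ - y₂‖ ≤ γ‖u₁ - u₂‖` for any two inputs applied from the same internal state. A hedged sketch of what checking this might look like; the `REN` constructor and its `(x, u) -> (x⁺, y)` call signature are assumed from the rest of the package:

```julia
using LinearAlgebra
using Random

γ = 2.0
ps = LipschitzRENParams{Float64}(1, 4, 10, 1, γ; rng=MersenneTwister(1))
model = REN(ps)                      # explicit model built from the direct parameterisation

x0 = zeros(4, 5)                     # same state for both input batches
u1, u2 = randn(1, 5), randn(1, 5)
_, y1 = model(x0, u1)
_, y2 = model(x0, u2)

norm(y1 - y2) ≤ γ * norm(u1 - u2)    # should evaluate to true
```
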
22 changes: 3 additions & 19 deletions src/ParameterTypes/passive_ren.jl
@@ -27,7 +27,7 @@ Construct direct parameterisation of a passive REN.
- `ν::T=0`: Passivity parameter. Use ν>0 for incrementally strictly input passive model, and ν == 0 for incrementally passive model.
- - `nl::Function=Flux.relu`: Sector-bounded static nonlinearity.
+ - `nl::Function=relu`: Sector-bounded static nonlinearity.
- `αbar::T=1`: Upper bound on the contraction rate with `ᾱ ∈ (0,1]`.
@@ -38,7 +38,7 @@ See also [`GeneralRENParams`](@ref), [`ContractingRENParams`](@ref), [`Lipschitz
function PassiveRENParams{T}(
nu::Int, nx::Int, nv::Int, ny::Int;
ν::T = T(0),
- nl::Function = Flux.relu,
+ nl::Function = relu,
αbar::T = T(1),
init = :random,
polar_param::Bool = true,
@@ -64,23 +64,7 @@ function PassiveRENParams{T}(

end

- Flux.@functor PassiveRENParams (direct, )

- function Flux.gpu(m::PassiveRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.gpu(m.direct)
- return PassiveRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar, m.ν
- )
- end

- function Flux.cpu(m::PassiveRENParams{T}) where T
- # TODO: Test and complete this
- direct_ps = Flux.cpu(m.direct)
- return PassiveRENParams{T}(
- m.nl, m.nu, m.nx, m.nv, m.ny, direct_ps, m.αbar, m.ν
- )
- end
+ @functor PassiveRENParams (direct, )

function direct_to_explicit(ps::PassiveRENParams{T}, return_h=false) where T

14 changes: 8 additions & 6 deletions src/RobustNeuralNetworks.jl
@@ -5,15 +5,20 @@ module RobustNeuralNetworks
############ Package dependencies ############

using ChainRulesCore: NoTangent, @non_differentiable
- using Flux
+ using Flux: relu, identity, @functor
using LinearAlgebra
using MatrixEquations: lyapd
using Random
using Zygote: Buffer

import Base.:(==)
import ChainRulesCore: rrule
- import Flux.gpu, Flux.cpu
+ import Flux: trainable, glorot_normal

+ # Note: to remove explicit dependency on Flux.jl, use the following
+ # using Functors: @functor
+ # using NNlib: relu, identity
+ # import Optimisers.trainable
+ # and re-write `glorot_normal` yourself.


############ Abstract types ############
@@ -108,7 +113,4 @@ export init_states
export set_output_zero!
export update_explicit!

- # Extended functions
- # TODO: Need to export things like gpu, cpu, ==, etc.

end # end RobustNeuralNetworks
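The commented note in the first hunk above sketches how the explicit Flux.jl dependency could be dropped entirely in a future revision. A rough, untested sketch of that replacement preamble, assuming Functors.jl, NNlib.jl and Optimisers.jl as direct dependencies and a hand-rolled Glorot initialiser:

```julia
using Functors: @functor
using NNlib: relu
using Random
import Optimisers: trainable

# Hand-rolled Glorot (Xavier) normal init: std = sqrt(2 / (fan_in + fan_out)).
# Only covers the vector and matrix cases this package needs.
function glorot_normal(rng::AbstractRNG, dims::Integer...)
    fan_sum = length(dims) == 1 ? 1 + dims[1] : dims[1] + dims[2]
    return randn(rng, Float32, dims...) .* sqrt(2f0 / fan_sum)
end
```
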
2 changes: 1 addition & 1 deletion src/Wrappers/LBDN/diff_lbdn.jl
@@ -42,7 +42,7 @@ function (m::DiffLBDN)(u::AbstractVecOrMat)
return m(u, explicit)
end

- Flux.@functor DiffLBDN (params, )
+ @functor DiffLBDN (params, )

function set_output_zero!(m::DiffLBDN)
set_output_zero!(m.params)
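The `(params, )` argument to `@functor` restricts which fields Functors.jl treats as children, so only the nested parameter struct is traversed (for example when collecting parameters or moving a model between devices) while the remaining fields are left untouched. A small hypothetical illustration:

```julia
using Flux: @functor

struct Wrapped
    params            # traversed by Functors (listed below)
    nl::Function      # ignored by Functors
end

@functor Wrapped (params,)   # only `params` is a functor child
```
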
