Merge pull request #101 from acfr/bugfix/linear-REN
Bugfix: error when taking gradients with `nv = 0` or `nx = 0`
nic-barbara committed Jul 13, 2023
2 parents dfc1e4c + be6281e commit 966f852
Showing 7 changed files with 45 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/ParameterTypes/utils.jl
@@ -78,6 +78,7 @@ function hmatrix_to_explicit(ps::AbstractRENParams, H::AbstractMatrix{T}, D22::A
B1 = E \ B1_imp
B2 = E \ ps.direct.B2

+ # Current versions of Julia behave poorly when broadcasting over 0-dim arrays
C1 = (nv == 0) ? zeros(T,0,nx) : broadcast(*, Λ_inv, C1_imp)
D11 = (nv == 0) ? zeros(T,0,0) : broadcast(*, Λ_inv, D11_imp)
D12 = (nv == 0) ? zeros(T,0,nu) : broadcast(*, Λ_inv, ps.direct.D12)
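The guard in this hunk returns an explicitly sized empty matrix instead of broadcasting whenever `nv == 0`. A minimal standalone sketch of the same pattern, with variable names mirroring the hunk and sizes chosen purely for illustration (they are not taken from the package):

```julia
# Zero-row guard: when nv == 0 there are no rows to scale with Λ_inv,
# so an explicitly sized empty matrix is returned rather than relying
# on broadcasting over an array with a zero dimension.
T = Float64
nv, nx = 0, 5                      # illustrative sizes
Λ_inv  = ones(T, nv)               # empty vector when nv == 0
C1_imp = randn(T, nv, nx)          # 0×nx implicit matrix

C1 = (nv == 0) ? zeros(T, 0, nx) : broadcast(*, Λ_inv, C1_imp)
size(C1)                           # (0, 5) in either branch
```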
2 changes: 1 addition & 1 deletion src/RobustNeuralNetworks.jl
@@ -6,7 +6,7 @@ module RobustNeuralNetworks

using Flux
using LinearAlgebra
- using MatrixEquations: lyapd, plyapd
+ using MatrixEquations: lyapd
using Random
using Zygote: pullback, Buffer
using Zygote: @adjoint
9 changes: 7 additions & 2 deletions src/Wrappers/REN/ren.jl
@@ -77,9 +77,14 @@ function (m::AbstractREN{T})(
explicit::ExplicitRENParams{T}
) where T

- b = explicit.C1 * xt + explicit.D12 * ut .+ explicit.bv
+ # Allocate bias vectors to avoid error when nv = 0 or nx = 0
+ # TODO: if statement (or equivalent) makes backpropagation slower. Can we avoid this?
+ bv = (m.nv == 0) ? 0 : explicit.bv
+ bx = (m.nx == 0) ? 0 : explicit.bx
+
+ b = explicit.C1 * xt + explicit.D12 * ut .+ bv
wt = tril_eq_layer(m.nl, explicit.D11, b)
- xt1 = explicit.A * xt + explicit.B1 * wt + explicit.B2 * ut .+ explicit.bx
+ xt1 = explicit.A * xt + explicit.B1 * wt + explicit.B2 * ut .+ bx
yt = explicit.C2 * xt + explicit.D21 * wt + explicit.D22 * ut .+ explicit.by

return xt1, yt
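The new fallback replaces an empty bias vector with the scalar `0` whenever the corresponding dimension is zero, so the broadcasted additions stay well-posed. A minimal sketch of just that pattern, with illustrative sizes that are assumptions rather than values from the package:

```julia
# Scalar-zero bias fallback: with nv == 0 or nx == 0 the bias becomes
# the scalar 0, so `.+ bv` / `.+ bx` broadcast trivially over empty
# 0×batch matrices. Per the commit message, the empty-vector form
# errored when taking gradients.
nv, nx, batches = 0, 0, 10         # illustrative sizes
bv = (nv == 0) ? 0 : randn(nv)
bx = (nx == 0) ? 0 : randn(nx)

b  = zeros(nv, batches) .+ bv      # stays 0×10
x1 = zeros(nx, batches) .+ bx      # stays 0×10
```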
2 changes: 0 additions & 2 deletions test/Wrappers/diff_lbdn.jl
@@ -5,8 +5,6 @@ using Random
using RobustNeuralNetworks
using Test

- # include("../test_utils.jl")
-
"""
Test that backpropagation runs and parameters change
"""
6 changes: 2 additions & 4 deletions test/Wrappers/diff_ren.jl
@@ -5,15 +5,13 @@ using Random
using RobustNeuralNetworks
using Test

- # include("../test_utils.jl")
-
"""
Test that backpropagation runs and parameters change
"""
batches = 10
- nu, nx, nv, ny = 4, 5, 10, 3
+ nu, nx, nv, ny = 4, 5, 10, 2
γ = 10
- ren_ps = LipschitzRENParams{Float32}(nu, nx, nv, ny, γ)
+ ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ)
model = DiffREN(ren_ps)

# Dummy data
33 changes: 33 additions & 0 deletions test/Wrappers/zero_dim.jl
@@ -0,0 +1,33 @@
# This file is a part of RobustNeuralNetworks.jl. License is MIT: https://github.com/acfr/RobustNeuralNetworks.jl/blob/main/LICENSE

using Flux
using Random
using RobustNeuralNetworks
using Test

"""
Test that backpropagation runs when nx = 0 and nv = 0
"""
batches = 10
nu, nx, nv, ny = 4, 0, 0, 2
γ = 10
model_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ)

# Dummy data
us = randn(nu, batches)
ys = randn(ny, batches)
data = [(us, ys)]

# Dummy loss function just for testing
function loss(model_ps, u, y)
m = REN(model_ps)
x0 = init_states(m, size(u,2))
x1, y1 = m(x0, u)
return Flux.mse(y1, y) + sum(x1.^2)
end

# Make sure batch update actually runs
opt_state = Flux.setup(Adam(0.01), model_ps)
gs = Flux.gradient(loss, model_ps, us, ys)
Flux.update!(opt_state, model_ps, gs[1])
@test !isempty(gs)
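For completeness, the forward pass can be sanity-checked on its own in the same zero-dimension configuration. The snippet below reuses the test's sizes; the expected shapes in the comments assume `init_states` returns an `nx`-by-batch state matrix, which is an inference from its use in the test above rather than documented behaviour:

```julia
using RobustNeuralNetworks

nu, nx, nv, ny, γ = 4, 0, 0, 2, 10
ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ)
m  = REN(ps)

x0     = init_states(m, 10)        # presumably 0×10 when nx == 0
x1, y1 = m(x0, randn(nu, 10))      # x1 expected 0×10, y1 expected 2×10
```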
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -20,5 +20,6 @@ using Test
include("Wrappers/wrap_ren.jl")
include("Wrappers/diff_ren.jl")
include("Wrappers/diff_lbdn.jl")
include("Wrappers/zero_dim.jl")

end
