Skip to content

Commit

Permalink
Merge pull request #117 from acfr/100-improve-speed-with-back-propagation
Browse files Browse the repository at this point in the history

Improve speed with back propagation
  • Loading branch information
nic-barbara committed Aug 4, 2023
2 parents 1006401 + 2b10b9b commit bc59d8f
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/Base/ren_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ mutable struct DirectRENParams{T}
polar_param::Bool # Whether or not to use polar parameterisation
D22_free ::Bool # Is D22 free or parameterised by (X3,Y3,Z3)?
D22_zero ::Bool # Option to remove feedthrough.
output_map ::Bool
output_map ::Bool # Whether to include output map of REN
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/ParameterTypes/contracting_ren.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end

@functor ContractingRENParams (direct, )

function direct_to_explicit(ps::ContractingRENParams{T}, return_h::Bool=false) where T
function direct_to_explicit(ps::ContractingRENParams, return_h::Bool=false)

ϵ = ps.direct.ϵ
ρ = ps.direct.ρ[1]
Expand Down
48 changes: 35 additions & 13 deletions src/ParameterTypes/general_ren.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,18 @@ function direct_to_explicit(ps::GeneralRENParams{T}, return_h=false) where T
R1 = Hermitian(R - S * (Q \ S'))
LR = Matrix{T}(cholesky(R1).U)

M = X3'*X3 + Y3 - Y3' + Z3'*Z3 + ϵ*I
if ny >= nu
N = [(I - M) / (I + M); -2*Z3 / (I + M)]
else
N = [((I + M) \ (I - M)) (-2*(I + M) \ Z3')]
end
M = _M_gen(X3, Y3, Z3, ϵ)
N = _N_gen(nu, ny, M, Z3)

D22 = -(Q \ S') + (LQ \ N) * LR
D22 = _D22_gen(Q, S, LQ, LR, N)

# Constructing H. See Eqn 28 of TAC paper
C2_imp = (D22'*Q + S)*C2
D21_imp = (D22'*Q + S)*D21 - D12_imp'
C2_imp = _C2_gen(D22, C2, Q, S)
D21_imp = _D21_gen(D22, D21, D12_imp, Q, S)

𝑅 = R + S*D22 + D22'*S' + D22'*Q*D22

Γ1 = [C2'; D21'; zeros(T, nx, ny)] * Q * [C2 D21 zeros(T, ny, nx)]
Γ2 = [C2_imp'; D21_imp'; B2_imp] * (𝑅 \ [C2_imp D21_imp B2_imp'])
𝑅 = _R_gen(R, S, Q, D22)
Γ1 = _Γ1_gen(nx, ny, C2, D21, Q, T)
Γ2 = _Γ2_gen(C2_imp, D21_imp, B2_imp, 𝑅)

H = x_to_h(X, ϵ, polar_param, ρ) + Γ2 - Γ1

Expand All @@ -144,3 +139,30 @@ function direct_to_explicit(ps::GeneralRENParams{T}, return_h=false) where T
return H

end

# Reverse-mode auto-diff is faster through many small functions than one
# monolithic expression, hence these helpers.

# Free positive-definite parameterisation of M used to construct D22
# (see the D22 parameterisation in the TAC paper): X3'X3 + Z3'Z3 + ϵI is
# PSD and Y3 - Y3' contributes only a skew-symmetric part.
function _M_gen(X3, Y3, Z3, ϵ)
    return X3'*X3 + Y3 - Y3' + Z3'*Z3 + ϵ*I
end

# Cayley-type construction of N (the free, norm-bounded part of D22) from M,
# handling the tall (ny >= nu) and wide (ny < nu) cases separately.
#
# NOTE(review): the original wide-case expression `-2*(I + M) \ Z3'` parses as
# `(-2*(I + M)) \ Z3'` because `*` and `\` share precedence and associate
# left-to-right, i.e. it computed -(1/2)*inv(I + M)*Z3' rather than the
# -2*inv(I + M)*Z3' that mirrors the tall-case term `-2*Z3 / (I + M)`.
# Parenthesised so both branches apply the same -2 factor.
function _N_gen(nu, ny, M, Z3)
    if ny >= nu
        return [(I - M) / (I + M); -2*Z3 / (I + M)]
    else
        return [((I + M) \ (I - M)) (-2 * ((I + M) \ Z3'))]
    end
end

# Assemble D22 from its free parameterisation (offset plus scaled N).
function _D22_gen(Q, S, LQ, LR, N)
    return -(Q \ S') + (LQ \ N) * LR
end

# Implicit output map used to build H (Eqn. 28 of the TAC paper).
function _C2_gen(D22, C2, Q, S)
    return (D22'*Q + S) * C2
end

# Implicit feedthrough map used to build H (Eqn. 28 of the TAC paper).
function _D21_gen(D22, D21, D12_imp, Q, S)
    return (D22'*Q + S) * D21 - D12_imp'
end

# The matrix R + S*D22 + D22'*S' + D22'*Q*D22 appearing in Eqn. 28.
function _R_gen(R, S, Q, D22)
    return R + S*D22 + D22'*S' + D22'*Q*D22
end

# Γ1 term of H (Eqn. 28): quadratic form in Q of [C2 D21 0], zero-padded
# with nx rows/columns. `T` fixes the eltype of the zero blocks.
function _Γ1_gen(nx, ny, C2, D21, Q, T)
    left = [C2'; D21'; zeros(T, nx, ny)]
    right = [C2 D21 zeros(T, ny, nx)]
    return left * Q * right
end

# Γ2 term of H (Eqn. 28): quadratic form in inv(𝑅) of the implicit maps.
function _Γ2_gen(C2_imp, D21_imp, B2_imp, 𝑅)
    stacked = [C2_imp'; D21_imp'; B2_imp]
    return stacked * (𝑅 \ [C2_imp D21_imp B2_imp'])
end
40 changes: 31 additions & 9 deletions src/ParameterTypes/lipschitz_ren.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,47 @@ function direct_to_explicit(ps::LipschitzRENParams{T}, return_h=false) where T
if ps.direct.D22_zero
D22 = ps.direct.D22
else
M = X3'*X3 + Y3 - Y3' + Z3'*Z3 + ϵ*I
N = (ny >= nu) ? [(I - M) / (I + M); -2*Z3 / (I + M)] :
[((I + M) \ (I - M)) (-2*(I + M) \ Z3')]
M = _M_lip(X3, Y3, Z3, ϵ)
N = _N_lip(nu, ny, M, Z3)
D22 = γ*N
end

# Constructing H. See Eqn 28 of TAC paper
C2_imp = -(D22')*C2 / γ
D21_imp = -(D22')*D21 / γ - D12_imp'
C2_imp = _C2_lip(D22, C2, γ)
D21_imp = _D21_lip(D22, D21, γ, D12_imp)

𝑅 = -D22'*D22 / γ +* I)

Γ1 = [C2'; D21'; zeros(T, nx, ny)] * [C2 D21 zeros(T, ny, nx)] * (-1/γ)
Γ2 = [C2_imp'; D21_imp'; B2_imp] * (𝑅 \ [C2_imp D21_imp B2_imp'])
𝑅 = _R_lip(D22, γ)
Γ1 = _Γ1_lip(nx, ny, C2, D21, γ, T)
Γ2 = _Γ2_lip(C2_imp, D21_imp, B2_imp, 𝑅)

H = x_to_h(X, ϵ, polar_param, ρ) + Γ2 - Γ1

# Get explicit parameterisation
!return_h && (return hmatrix_to_explicit(ps, H, D22))
return H
end

# Small helpers: reverse-mode auto-diff is faster through many small
# functions than through one large expression.

# Free positive-definite parameterisation of M for the Lipschitz-bounded
# D22 construction (same form as the general REN case).
function _M_lip(X3, Y3, Z3, ϵ)
    return X3'*X3 + Y3 - Y3' + Z3'*Z3 + ϵ*I
end

# Cayley-type construction of N = D22/γ from M for the Lipschitz-bounded REN.
#
# NOTE(review): the original wide-case expression `-2*(I + M) \ Z3'` parses as
# `(-2*(I + M)) \ Z3'` (`*` and `\` share precedence, left-associative), so it
# applied a factor of -1/2 instead of the -2 that mirrors the ny >= nu branch
# `-2*Z3 / (I + M)`. Parenthesised so both branches use the same -2 factor.
function _N_lip(nu, ny, M, Z3)
    if ny >= nu
        return [(I - M) / (I + M); -2*Z3 / (I + M)]
    else
        return [((I + M) \ (I - M)) (-2 * ((I + M) \ Z3'))]
    end
end

# Implicit output map for the Lipschitz-bounded REN (Eqn. 28 of the TAC paper).
_C2_lip(D22, C2, γ) = -(D22')*C2 / γ

# Implicit feedthrough map for the Lipschitz-bounded REN.
_D21_lip(D22, D21, γ, D12_imp) = -(D22')*D21 / γ - D12_imp'

# 𝑅 = γI - D22'D22/γ for the Lipschitz-bounded REN.
# NOTE(review): the rendered diff dropped the "(γ" token here; reconstructed
# as `+ (γ * I)` to match the pre-refactor line it replaces.
_R_lip(D22, γ) = -D22'*D22 / γ + (γ * I)

# Γ1 term of H for the Lipschitz-bounded REN: zero-padded outer product of
# [C2 D21 0], scaled by -1/γ. `T` fixes the eltype of the zero blocks.
function _Γ1_lip(nx, ny, C2, D21, γ, T)
    left = [C2'; D21'; zeros(T, nx, ny)]
    right = [C2 D21 zeros(T, ny, nx)]
    return left * right * (-1/γ)
end

# Γ2 term of H for the Lipschitz-bounded REN: quadratic form in inv(𝑅)
# of the implicit maps.
function _Γ2_lip(C2_imp, D21_imp, B2_imp, 𝑅)
    stacked = [C2_imp'; D21_imp'; B2_imp]
    return stacked * (𝑅 \ [C2_imp D21_imp B2_imp'])
end
16 changes: 11 additions & 5 deletions src/ParameterTypes/passive_ren.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ function direct_to_explicit(ps::PassiveRENParams{T}, return_h=false) where T

# System sizes
nu = ps.nu
nx = ps.nx
ny = ps.ny
ν = ps.ν

Expand All @@ -93,14 +92,13 @@ function direct_to_explicit(ps::PassiveRENParams{T}, return_h=false) where T
# Constructing D22 for incrementally passive and incrementally strictly input passive.
# See Eqns 31-33 of TAC paper
# Currently converts to Hermitian to avoid numerical conditioning issues
M = X3'*X3 + Y3 - Y3' + ϵ*I
M = _M_pass(X3, Y3, ϵ)

D22 = ν*Matrix(I, ny,nu) + M
D21_imp = D21 - D12_imp'

𝑅 = -2ν * Matrix(I, nu, nu) + D22 + D22'

Γ2 = [C2'; D21_imp'; B2_imp] * (𝑅 \ [C2 D21_imp B2_imp'])
𝑅 = _R_pass(nu, D22, ν)
Γ2 = _Γ2_pass(C2, D21_imp, B2_imp, 𝑅)

H = x_to_h(X, ϵ, polar_param, ρ) + Γ2

Expand All @@ -109,3 +107,11 @@ function direct_to_explicit(ps::PassiveRENParams{T}, return_h=false) where T
return H

end

# Free parameterisation of M for the passive D22 (Eqns. 31-33 of the TAC
# paper): PSD part X3'X3 + ϵI plus skew-symmetric part Y3 - Y3'.
function _M_pass(X3, Y3, ϵ)
    return X3'*X3 + Y3 - Y3' + ϵ*I
end

# 𝑅 = D22 + D22' - 2νI for the passive REN.
function _R_pass(nu, D22, ν)
    return -2ν * Matrix(I, nu, nu) + D22 + D22'
end

# Γ2 term of H for the passive REN: quadratic form in inv(𝑅).
function _Γ2_pass(C2, D21_imp, B2_imp, 𝑅)
    stacked = [C2'; D21_imp'; B2_imp]
    return stacked * (𝑅 \ [C2 D21_imp B2_imp'])
end
33 changes: 20 additions & 13 deletions src/ParameterTypes/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Convert direct parameterisation of REN from H matrix (Eqn. 23 of [Revay et al. (
- `H::Matrix{T}`: H-matrix to convert.
- `D22::Matrix{T}=zeros(T,0,0))`: Optionally include `D22` matrix. If empty (default), `D22` taken from `ps.direct.D22`.
"""
function hmatrix_to_explicit(ps::AbstractRENParams, H::AbstractMatrix{T}, D22::AbstractMatrix{T} = zeros(T,0,0)) where T<:Real
function hmatrix_to_explicit(ps::AbstractRENParams, H::AbstractMatrix{T}, D22::AbstractMatrix{T} = zeros(T,0,0)) where T

# System sizes
nx = ps.nx
Expand All @@ -54,7 +54,7 @@ function hmatrix_to_explicit(ps::AbstractRENParams, H::AbstractMatrix{T}, D22::A
Y1 = ps.direct.Y1

# Extract sections of H matrix
# Using @view is faster but not supported by CUDA
# Using @view is slower in reverse mode?
H11 = H[1:nx, 1:nx]
H22 = H[nx + 1:nx + nv, nx + 1:nx + nv]
H33 = H[nx + nv + 1:2nx + nv, nx + nv + 1:2nx + nv]
Expand All @@ -65,40 +65,47 @@ function hmatrix_to_explicit(ps::AbstractRENParams, H::AbstractMatrix{T}, D22::A
# Construct implicit model parameters
P_imp = H33
F = H31
E = (H11 + P_imp/ᾱ^2 + Y1 - Y1')/2
E = _E(H11, P_imp, ᾱ, Y1)

# Equilibrium network parameters
B1_imp = H32
C1_imp = -H21
Λ_inv = ( (1 ./ diag(H22)) * 2)
Λ_inv = _Λ_inv(H22)
D11_imp = (-tril(H22, -1))

# Construct the explicit model
A = E \ F
B1 = E \ B1_imp
B2 = E \ ps.direct.B2

# Current versions of Julia behave poorly when broadcasting over 0-dim arrays
C1 = (nv == 0) ? zeros(T,0,nx) : broadcast(*, Λ_inv, C1_imp)
D11 = (nv == 0) ? zeros(T,0,0) : broadcast(*, Λ_inv, D11_imp)
D12 = (nv == 0) ? zeros(T,0,nu) : broadcast(*, Λ_inv, ps.direct.D12)
# Equilibrium layer matrices
C1 = _C1(nv, nx, Λ_inv, C1_imp, T)
D11 = _D11(nv, Λ_inv, D11_imp, T)
D12 = _D12(nv, nu, Λ_inv, ps.direct.D12, T)

# Output layer
C2 = ps.direct.C2
D21 = ps.direct.D21

isempty(D22) && (D22 = ps.direct.D22)

# Biases
bx = ps.direct.bx
bv = ps.direct.bv
by = ps.direct.by

return ExplicitRENParams{T}(A, B1, B2, C1, C2, D11, D12, D21, D22, bx, bv, by)

end

function x_to_h(X::AbstractMatrix{T}, ϵ::T, polar_param::Bool, ρ::T) where T
polar_param ?^2)*(X'*X) / norm(X)^2 + ϵ*I : X'*X + ϵ*I
end
# Splitting operations into small functions speeds up reverse-mode auto-diff.

# E = (H11 + P/ᾱ² + Y1 - Y1')/2 (implicit state matrix, Eqn. 24 of TAC paper).
_E(H11, P_imp, ᾱ, Y1) = (H11 + P_imp/ᾱ^2 + Y1 - Y1')/2

# Inverse of Λ = diag(H22)/2, returned as a vector for cheap broadcasting.
_Λ_inv(H22) = ( (1 ./ diag(H22)) * 2)

# Current versions of Julia behave poorly when broadcasting over 0-dim
# arrays, so return explicitly-sized empty matrices when nv == 0.
_C1(nv, nx, Λ_inv, C1_imp, T) = (nv == 0) ? zeros(T,0,nx) : broadcast(*, Λ_inv, C1_imp)
_D11(nv, Λ_inv, D11_imp, T) = (nv == 0) ? zeros(T,0,0) : broadcast(*, Λ_inv, D11_imp)
_D12(nv, nu, Λ_inv, D12_imp, T) = (nv == 0) ? zeros(T,0,nu) : broadcast(*, Λ_inv, D12_imp)

# Contribution of X to the H matrix. The polar parameterisation rescales X'X
# to squared norm ρ²; otherwise H uses X'X directly (plus ϵI regularisation).
# NOTE(review): the rendered diff dropped the "(ρ" token; reconstructed to
# match the removed multi-line definition of x_to_h it replaces.
x_to_h(X, ϵ, polar_param, ρ) = polar_param ? (ρ^2)*(X'*X) / norm(X)^2 + ϵ*I : X'*X + ϵ*I

"""
set_output_zero!(m::AbstractRENParams)
Expand Down
17 changes: 11 additions & 6 deletions src/Wrappers/REN/ren.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,23 @@ function (m::AbstractREN{T})(
) where T

# Allocate bias vectors to avoid error when nv = 0 or nx = 0
# TODO: if statement (or equivalent) makes backpropagation slower. Can we avoid this?
bv = (m.nv == 0) ? 0 : explicit.bv
bx = (m.nx == 0) ? 0 : explicit.bx
bv = _bias(m.nv, explicit.bv)
bx = _bias(m.nx, explicit.bx)
by = explicit.by

b = explicit.C1 * xt + explicit.D12 * ut .+ bv
b = _b(explicit.C1, explicit.D12, xt, ut, bv)
wt = tril_eq_layer(m.nl, explicit.D11, b)
xt1 = explicit.A * xt + explicit.B1 * wt + explicit.B2 * ut .+ bx
yt = explicit.C2 * xt + explicit.D21 * wt + explicit.D22 * ut .+ explicit.by
xt1 = _xt1(explicit.A, explicit.B1, explicit.B2, xt, wt, ut, bx)
yt = _yt(explicit.C2, explicit.D21, explicit.D22, xt, wt, ut, by)

return xt1, yt
end

# Bias helper: returns the scalar 0 when the dimension is zero, so that
# broadcasting `.+ bias` is a harmless no-op for empty states/neurons.
_bias(n, b) = n == 0 ? 0 : b

# Pre-activation input to the equilibrium layer.
function _b(C1, D12, xt, ut, bv)
    return C1 * xt + D12 * ut .+ bv
end

# State update of the explicit REN model.
function _xt1(A, B1, B2, xt, wt, ut, bx)
    return A * xt + B1 * wt + B2 * ut .+ bx
end

# Output map of the explicit REN model.
function _yt(C2, D21, D22, xt, wt, ut, by)
    return C2 * xt + D21 * wt + D22 * ut .+ by
end

"""
init_states(m::AbstractREN, nbatches; rng=nothing)
Expand Down

0 comments on commit bc59d8f

Please sign in to comment.