Skip to content

Commit

Permalink
Add Module wrapper to ensure serializable for Turing (#109)
Browse files Browse the repository at this point in the history
* Fixes issue #70 by using wrapper
  • Loading branch information
jlperla committed Feb 28, 2022
1 parent e5a4318 commit 8a17d9a
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 98 deletions.
60 changes: 30 additions & 30 deletions src/generate_perturbation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function generate_perturbation(m::PerturbationModel, p_d, p_f, order::Val{1} = V
settings = PerturbationSolverSettings())
@assert cache.p_d_symbols == collect(Symbol.(keys(p_d)))

p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.p_symbols)
p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.m.p_symbols)

# solver type provided to all callbacks
ret = calculate_steady_state!(m, cache, settings, p)
Expand All @@ -39,7 +39,7 @@ function generate_perturbation(m::PerturbationModel, p_d, p_f, order::Val{2};
@assert cache.p_d_symbols == collect(Symbol.(keys(p_d)))
@assert cache.order == Val(2)

p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.p_symbols)
p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.m.p_symbols)

# Calculate the first-order perturbation
sol_first = generate_perturbation(m, p_d, p_f, Val(1); cache, settings)
Expand All @@ -62,30 +62,30 @@ function calculate_steady_state!(m::PerturbationModel, c, settings, p)

(settings.print_level > 2) && println("Calculating steady state")
try
if !isnothing(m.mod.ȳ!) && !isnothing(m.mod.x̄!) # use closed form if possible
m.mod.ȳ!(c.y, p)
m.mod.x̄!(c.x, p)
isnothing(m.mod.ȳ_p!) ||
fill_array_by_symbol_dispatch(m.mod.ȳ_p!, c.y_p, c.p_d_symbols, p)
isnothing(m.mod.x̄_p!) ||
fill_array_by_symbol_dispatch(m.mod.x̄_p!, c.x_p, c.p_d_symbols, p)
elseif !isnothing(m.mod.steady_state!) # use user-provided calculation otherwise
m.mod.steady_state!(c.y, c.x, p)
if !isnothing(m.mod.m.ȳ!) && !isnothing(m.mod.m.x̄!) # use closed form if possible
m.mod.m.ȳ!(c.y, p)
m.mod.m.x̄!(c.x, p)
isnothing(m.mod.m.ȳ_p!) ||
fill_array_by_symbol_dispatch(m.mod.m.ȳ_p!, c.y_p, c.p_d_symbols, p)
isnothing(m.mod.m.x̄_p!) ||
fill_array_by_symbol_dispatch(m.mod.m.x̄_p!, c.x_p, c.p_d_symbols, p)
elseif !isnothing(m.mod.m.steady_state!) # use user-provided calculation otherwise
m.mod.m.steady_state!(c.y, c.x, p)
else # fallback is to solve system of equations from user-provided initial condition
y_0 = zeros(n_y)
x_0 = zeros(n_x)
m.mod.ȳ_iv!(y_0, p)
m.mod.x̄_iv!(x_0, p)
m.mod.m.ȳ_iv!(y_0, p)
m.mod.m.x̄_iv!(x_0, p)
w_0 = [y_0; x_0]

if isnothing(m.mod.H̄_w!) # no jacobian
nlsol = nlsolve((H, w) -> m.mod.H̄!(H, w, p), w_0;
if isnothing(m.mod.m.H̄_w!) # no jacobian
nlsol = nlsolve((H, w) -> m.mod.m.H̄!(H, w, p), w_0;
DifferentiableStateSpaceModels.nlsolve_options(settings)...)
else
J_0 = zeros(n, n)
F_0 = zeros(n)
df = OnceDifferentiable((H, w) -> m.mod.H̄!(H, w, p),
(J, w) -> m.mod.H̄_w!(J, w, p), w_0, F_0, J_0) # TODO: the buffer to use for the w_0 is unclear?
df = OnceDifferentiable((H, w) -> m.mod.m.H̄!(H, w, p),
(J, w) -> m.mod.m.H̄_w!(J, w, p), w_0, F_0, J_0) # TODO: the buffer to use for the w_0 is unclear?
nlsol = nlsolve(df, w_0;
DifferentiableStateSpaceModels.nlsolve_options(settings)...)
end
Expand Down Expand Up @@ -123,13 +123,13 @@ function evaluate_first_order_functions!(m, c, settings, p)
try
@unpack y, x = c # Precondition: valid (y, x) steady states

m.mod.H_yp!(c.H_yp, y, x, p)
m.mod.H_y!(c.H_y, y, x, p)
m.mod.H_xp!(c.H_xp, y, x, p)
m.mod.H_x!(c.H_x, y, x, p)
m.mod.Γ!(c.Γ, p)
maybe_call_function(m.mod.Ω!, c.Ω, p) # supports m.mod.Ω! = nothing
(length(c.p_d_symbols) > 0) && m.mod.Ψ!(c.Ψ, y, x, p)
m.mod.m.H_yp!(c.H_yp, y, x, p)
m.mod.m.H_y!(c.H_y, y, x, p)
m.mod.m.H_xp!(c.H_xp, y, x, p)
m.mod.m.H_x!(c.H_x, y, x, p)
m.mod.m.Γ!(c.Γ, p)
maybe_call_function(m.mod.m.Ω!, c.Ω, p) # supports m.mod.m.Ω! = nothing
(length(c.p_d_symbols) > 0) && m.mod.m.Ψ!(c.Ψ, y, x, p)
catch e
if e isa DomainError
settings.print_level == 0 || display(e)
Expand All @@ -146,12 +146,12 @@ function evaluate_second_order_functions!(m, c, settings, p)
(settings.print_level > 2) && println("Evaluating second-order functions into cache")
try
@unpack y, x = c # Precondition: valid (y, x) steady states
(length(c.p_d_symbols) == 0) && m.mod.Ψ!(c.Ψ, y, x, p) # would have been called otherwise in first_order_functions
m.mod.Ψ!(c.Ψ, y, x, p)
m.mod.Ψ_yp!(c.Ψ_yp, y, x, p)
m.mod.Ψ_y!(c.Ψ_y, y, x, p)
m.mod.Ψ_xp!(c.Ψ_xp, y, x, p)
m.mod.Ψ_x!(c.Ψ_x, y, x, p)
(length(c.p_d_symbols) == 0) && m.mod.m.Ψ!(c.Ψ, y, x, p) # would have been called otherwise in first_order_functions
m.mod.m.Ψ!(c.Ψ, y, x, p)
m.mod.m.Ψ_yp!(c.Ψ_yp, y, x, p)
m.mod.m.Ψ_y!(c.Ψ_y, y, x, p)
m.mod.m.Ψ_xp!(c.Ψ_xp, y, x, p)
m.mod.m.Ψ_x!(c.Ψ_x, y, x, p)
catch e
if e isa DomainError
settings.print_level == 0 || display(e)
Expand Down
22 changes: 11 additions & 11 deletions src/generate_perturbation_derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ function generate_perturbation_derivatives!(m, p_d, p_f, cache::AbstractSolverCa
settings = PerturbationSolverSettings())
@assert cache.p_d_symbols == collect(Symbol.(keys(p_d)))

p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.p_symbols)
p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.m.p_symbols)

# Fill in derivatives
ret = evaluate_first_order_functions_p!(m, cache, settings, p)
Expand All @@ -17,7 +17,7 @@ function generate_perturbation_derivatives!(m, p_d, p_f, cache::AbstractSolverCa
settings = PerturbationSolverSettings())
@assert cache.p_d_symbols == collect(Symbol.(keys(p_d)))

p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.p_symbols)
p = isnothing(p_f) ? p_d : order_vector_by_symbols(merge(p_d, p_f), m.mod.m.p_symbols)

# Fill in derivatives, first by calling the first-order
# Fill in derivatives
Expand All @@ -40,14 +40,14 @@ function evaluate_first_order_functions_p!(m, c, settings, p)
try
@unpack y, x = c # Precondition: valid (y, x) steady states
isnothing(c.H_p) ||
fill_array_by_symbol_dispatch(m.mod.H_p!, c.H_p, c.p_d_symbols, y, x, p) #not required if steady_state_p!
fill_array_by_symbol_dispatch(m.mod.H_yp_p!, c.H_yp_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.H_y_p!, c.H_y_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.H_xp_p!, c.H_xp_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.H_x_p!, c.H_x_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.Γ_p!, c.Γ_p, c.p_d_symbols, p)
fill_array_by_symbol_dispatch(m.mod.m.H_p!, c.H_p, c.p_d_symbols, y, x, p) #not required if steady_state_p!
fill_array_by_symbol_dispatch(m.mod.m.H_yp_p!, c.H_yp_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_y_p!, c.H_y_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_xp_p!, c.H_xp_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_x_p!, c.H_x_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.Γ_p!, c.Γ_p, c.p_d_symbols, p)
isnothing(c.Ω_p) ||
fill_array_by_symbol_dispatch(m.mod.Ω_p!, c.Ω_p, c.p_d_symbols, p)
fill_array_by_symbol_dispatch(m.mod.m.Ω_p!, c.Ω_p, c.p_d_symbols, p)
catch e
if e isa DomainError
settings.print_level == 0 || display(e)
Expand All @@ -65,7 +65,7 @@ function evaluate_second_order_functions_p!(m, c, settings, p)
println("Evaluating second-order function derivatives into cache")
try
@unpack y, x = c # Precondition: valid (y, x) steady states
fill_array_by_symbol_dispatch(m.mod.Ψ_p!, c.Ψ_p, c.p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.Ψ_p!, c.Ψ_p, c.p_d_symbols, y, x, p)
catch e
if e isa DomainError
settings.print_level == 0 || display(e)
Expand All @@ -86,7 +86,7 @@ function solve_first_order_p!(m, c, settings)

buff = c.first_order_solver_p_buffer
try
if isnothing(m.mod.ȳ_p!) && isnothing(m.mod.x̄_p!)
if isnothing(m.mod.m.ȳ_p!) && isnothing(m.mod.m.x̄_p!)
# Zeroth-order derivatives if not provided
# Calculating c.y_p, c.x_p
A_zero = [c.H_y + c.H_yp c.H_x + c.H_xp]
Expand Down
50 changes: 32 additions & 18 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import Base.deepcopy_internal
# A wrapper for Module. The only purpose is a specialization of deepcopy since otherwise the "mod" property in the PerturbationModule brakes multithreaded MCMC
struct ModuleWrapper
m::Module
end
function deepcopy_internal(x::ModuleWrapper, stackdict::IdDict)
if haskey(stackdict, x)
return stackdict[x]::ModuleWrapper
end
y = S(x.m)
stackdict[x] = y
return y
end

deepcopy_internal(x::Module, stackdict::IdDict) = x

# Model Types. The template args are required for inference for cache/perturbation solutions
struct PerturbationModel{MaxOrder,N_y,N_x,N_ϵ,N_z,N_p,HasΩ,T1,T2}
mod::Module
mod::ModuleWrapper

# Could extract from type, but here for simplicity
max_order::Int64
Expand All @@ -22,17 +35,17 @@ end
# Construct from a module. Inherently type unstable, so use function barrier from return type
function PerturbationModel(mod)
return PerturbationModel{mod.max_order,mod.n_y,mod.n_x,mod.n_ϵ,mod.n_z,mod.n_p,
mod.has_Ω,typeof(mod.η),typeof(mod.Q)}(mod, mod.max_order,
mod.n_y, mod.n_x,
mod.n_p, mod.n_ϵ,
mod.n_z, mod.has_Ω,
mod.η, mod.Q)
mod.has_Ω,typeof(mod.η),typeof(mod.Q)}(ModuleWrapper(mod),
mod.max_order, mod.n_y,
mod.n_x, mod.n_p,
mod.n_ϵ, mod.n_z,
mod.has_Ω, mod.η, mod.Q)
end

# TODO: Add in latex stuff for the mod.H_latex,
function Base.show(io::IO, ::MIME"text/plain", m::PerturbationModel) where {T}
return print(io,
"Perturbation Model: n_y = $(m.n_y), n_x = $(m.n_x), n_p = $(m.n_p), n_ϵ = $(m.n_ϵ), n_z = $(m.n_z)\n y = $(m.mod.y_symbols) \n x = $(m.mod.x_symbols) \n p = $(m.mod.p_symbols)")
"Perturbation Model: n_y = $(m.n_y), n_x = $(m.n_x), n_p = $(m.n_p), n_ϵ = $(m.n_ϵ), n_z = $(m.n_z)\n y = $(m.mod.m.y_symbols) \n x = $(m.mod.m.x_symbols) \n p = $(m.mod.m.p_symbols)")
end

# Buffers for the solvers to reduce allocations
Expand Down Expand Up @@ -381,11 +394,11 @@ maybe_diagonal(x::AbstractVector) = MvNormal(Diagonal(abs2.(x)))
maybe_diagonal(x) = x # otherwise, just return raw. e.g. nothing

function FirstOrderPerturbationSolution(retcode, m::PerturbationModel, c::SolverCache)
return FirstOrderPerturbationSolution(; retcode, m.mod.x_symbols, m.mod.y_symbols,
m.mod.u_symbols, m.mod.p_symbols, c.p_d_symbols,
m.n_x, m.n_y, m.n_p, m.n_ϵ, m.n_z, c.Q, c.η, c.y,
c.x, c.B, D = maybe_diagonal(c.Ω), c.g_x,
A = c.h_x, C = c.C_1,
return FirstOrderPerturbationSolution(; retcode, m.mod.m.x_symbols, m.mod.m.y_symbols,
m.mod.m.u_symbols, m.mod.m.p_symbols,
c.p_d_symbols, m.n_x, m.n_y, m.n_p, m.n_ϵ, m.n_z,
c.Q, c.η, c.y, c.x, c.B, D = maybe_diagonal(c.Ω),
c.g_x, A = c.h_x, C = c.C_1,
x_ergodic = MvNormal(zeros(m.n_x), c.V), # construct with PDMat already taken cholesky
c.Γ)
end
Expand Down Expand Up @@ -434,10 +447,11 @@ Base.@kwdef struct SecondOrderPerturbationSolution{T1<:AbstractVector,T2<:Abstra
end

function SecondOrderPerturbationSolution(retcode, m::PerturbationModel, c::SolverCache)
return SecondOrderPerturbationSolution(; retcode, m.mod.x_symbols, m.mod.y_symbols,
m.mod.u_symbols, m.mod.p_symbols, c.p_d_symbols,
m.n_x, m.n_y, m.n_p, m.n_ϵ, m.n_z, c.Q, c.η, c.y,
c.x, c.B, D = maybe_diagonal(c.Ω), c.Γ, c.g_x,
A_1 = c.h_x, c.g_xx, A_2 = 0.5 * c.h_xx, c.g_σσ,
A_0 = 0.5 * c.h_σσ, c.C_1, c.C_0, c.C_2)
return SecondOrderPerturbationSolution(; retcode, m.mod.m.x_symbols, m.mod.m.y_symbols,
m.mod.m.u_symbols, m.mod.m.p_symbols,
c.p_d_symbols, m.n_x, m.n_y, m.n_p, m.n_ϵ, m.n_z,
c.Q, c.η, c.y, c.x, c.B, D = maybe_diagonal(c.Ω),
c.Γ, c.g_x, A_1 = c.h_x, c.g_xx,
A_2 = 0.5 * c.h_xx, c.g_σσ, A_0 = 0.5 * c.h_σσ,
c.C_1, c.C_0, c.C_2)
end
34 changes: 17 additions & 17 deletions test/first_order_perturbation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,35 @@ end
c = SolverCache(m, Val(1), p_d)

# Create parameter vector in the same ordering the internal algorithms would
p = order_vector_by_symbols(merge(p_d, p_f), m.mod.p_symbols)
p = order_vector_by_symbols(merge(p_d, p_f), m.mod.m.p_symbols)

y = zeros(m.n_y)
x = zeros(m.n_x)

m.mod.ȳ!(y, p)
m.mod.x̄!(x, p)
m.mod.m.ȳ!(y, p)
m.mod.m.x̄!(x, p)
@test y [5.936252888048733, 6.884057971014498]
@test x [47.39025414828825, 0.0]

m.mod.H_yp!(c.H_yp, y, x, p)
m.mod.m.H_yp!(c.H_yp, y, x, p)
@test c.H_yp [0.028377570562199098 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0]

m.mod.H_y!(c.H_y, y, x, p)
m.mod.m.H_y!(c.H_y, y, x, p)
@test c.H_y [-0.0283775705621991 0.0; 1.0 -1.0; 0.0 1.0; 0.0 0.0]

m.mod.H_xp!(c.H_xp, y, x, p)
m.mod.m.H_xp!(c.H_xp, y, x, p)
@test c.H_xp [0.00012263591151906127 -0.011623494029190608
1.0 0.0
0.0 0.0
0.0 1.0]

m.mod.H_x!(c.H_x, y, x, p)
m.mod.m.H_x!(c.H_x, y, x, p)
@test c.H_x [0.0 0.0
-0.98 0.0
-0.07263157894736837 -6.884057971014498
0.0 -0.2]

m.mod.Ψ!(c.Ψ, y, x, p)
m.mod.m.Ψ!(c.Ψ, y, x, p)
@test c.Ψ[1]
[-0.009560768753410337 0.0 0.0 0.0 -2.0658808482697935e-5 0.0019580523687917364 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
Expand All @@ -107,31 +107,31 @@ end
0.0 0.0 0.0 0.0 0.0 0.0 -0.07263157894736837 -6.884057971014498]
@test c.Ψ[4] zeros(8, 8)

m.mod.Γ!(c.Γ, p)
m.mod.m.Γ!(c.Γ, p)
@test c.Γ [0.01]

m.mod.Ω!(c.Ω, p)
m.mod.m.Ω!(c.Ω, p)
@test c.Ω [0.01, 0.01]

# The derivative ones dispatch by the derivative symbol
fill_array_by_symbol_dispatch(m.mod.H_x_p!, c.H_x_p, p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_x_p!, c.H_x_p, p_d_symbols, y, x, p)
@test c.H_x_p [[0.0 0.0
0.0 0.0
-0.4255060477077458 -26.561563542978472
0.0 0.0], [0.0 0.0
0.0 0.0;
0.0 0.0;
0.0 0.0]]
fill_array_by_symbol_dispatch(m.mod.H_yp_p!, c.H_yp_p, p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_yp_p!, c.H_yp_p, p_d_symbols, y, x, p)
@test c.H_yp_p [[0.011471086498795562 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0],
[0.029871126907577997 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0]]

fill_array_by_symbol_dispatch(m.mod.H_y_p!, c.H_y_p, p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_y_p!, c.H_y_p, p_d_symbols, y, x, p)
@test c.H_y_p [[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], [0.0 0.0
0.0 0.0;
0.0 0.0; 0.0 0.0]]

fill_array_by_symbol_dispatch(m.mod.H_xp_p!, c.H_xp_p, p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_xp_p!, c.H_xp_p, p_d_symbols, y, x, p)
@test c.H_xp_p [[0.000473180436623283 -0.06809527035753198
0.0 0.0
0.0 0.0
Expand All @@ -140,16 +140,16 @@ end
0.0 0.0
0.0 0.0]]

fill_array_by_symbol_dispatch(m.mod.Γ_p!, c.Γ_p, p_d_symbols, p)
fill_array_by_symbol_dispatch(m.mod.m.Γ_p!, c.Γ_p, p_d_symbols, p)

@test c.Γ_p [[0.0], [0.0]]

fill_array_by_symbol_dispatch(m.mod.H_p!, c.H_p, p_d_symbols, y, x, p)
fill_array_by_symbol_dispatch(m.mod.m.H_p!, c.H_p, p_d_symbols, y, x, p)

@test c.H_p [[-0.06809527035753199, 0.0, -26.561563542978472, 0.0],
[-0.1773225633743801, 0.0, 0.0, 0.0]]

fill_array_by_symbol_dispatch(m.mod.Ω_p!, c.Ω_p, p_d_symbols, p)
fill_array_by_symbol_dispatch(m.mod.m.Ω_p!, c.Ω_p, p_d_symbols, p)
@test c.Ω_p [[0.0, 0.0], [0.0, 0.0]]
end

Expand Down
2 changes: 1 addition & 1 deletion test/make_perturbation_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ using DifferentiableStateSpaceModels, Symbolics, Test
m = PerturbationModel(Main.rbc_temp)
@test m.n_y == n_y
@test m.max_order == max_order
@test m.mod.n_z == n_z
@test m.mod.m.n_z == n_z
# Note that this is inherently dynamic and cannot be inferred, so @inferred PerturbationModel(Main.rbc_observables) would fail

c = SolverCache(m, Val(2), [:a, :b, :c]) # the exact symbol names won't matter for inference
Expand Down

0 comments on commit 8a17d9a

Please sign in to comment.