Inconsistent types from MTK+DI #228

@wupeifan

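This might be related to #225. I suspect a bug in build_function.
Looking at the results, the function generated from DI(DI(f))(x) does not return the same element type as the one generated from DI(f)(x): evaluated at numeric inputs, it returns an Array{Expression,1} instead of an Array{Float64,1}.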

using ModelingToolkit
using ForwardDiff2: DI
using ChainRules

@scalar_rule(^(x, y), (y * x^(y - 1), Ω * log(x)))

function equations(vars, theta, n_exo, n_endo)
    # States:   X = (k, z),  next period XP = (k_p, z_p)
    # Controls: Y = (c),     next period YP = (c_p)
    XP = vars[1:n_exo]
    X = vars[(n_exo + 1):(2 * n_exo)]
    YP = vars[(2 * n_exo + 1):(2 * n_exo + n_endo)]
    Y = vars[(2 * n_exo + n_endo + 1):(2 * (n_exo + n_endo))]
    beta, alpha, delta, rho, sigma = theta
    equ = [ 1 / Y[1] - beta / YP[1] * (alpha * exp(XP[2]) * XP[1] ^ (alpha - 1) + 1 - delta),
            Y[1] + XP[1] - (1 - delta) * X[1] - exp(X[2]) * X[1] ^ alpha,
            XP[2] - rho * X[2]]
    return equ
end

@variables beta, alpha, delta, rho, sigma
@variables k, z, c, k_p, z_p, c_p

X_sym = [k_p, z_p, k, z, c_p, c]
θ_sym = [beta, alpha, delta, rho, sigma]

X_val = [1, 0, 1, 0, 1, 1]
θ_val = [0.99, 0.3, 0.1, 0.8, 0.01]

## The following returns symbolic Expressions, not numbers
∇²H = DI(DI(X_arg -> equations(X_arg, θ_sym, 2, 1)))(X_sym)
f_∇²H = eval(ModelingToolkit.build_function(∇²H, vcat(X_sym, θ_sym), linenumbers = false)[1])
julia> f_∇²H(vcat(X_val, θ_val))
108-element Array{Expression,1}:

## We should get numbers instead, as the first-order version below does
∇H = DI(X_arg -> equations(X_arg, θ_sym, 2, 1))(X_sym)
f_∇H = eval(ModelingToolkit.build_function(∇H, vcat(X_sym, θ_sym), linenumbers = false)[1])
julia> f_∇H(vcat(X_val, θ_val))
18-element Array{Float64,1}:
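For comparison, the numbers one would expect from both calls can be computed directly with ForwardDiff.jl, bypassing ModelingToolkit and build_function entirely. This is a minimal sketch, assuming ForwardDiff.jl (not ForwardDiff2) is installed. It plugs θ_val in before differentiating, and the ordering of the 108 second-derivative entries may not match the ordering of the flattened ∇²H above; only the element type and the counts are meant to line up.

```julia
using ForwardDiff  # assumption: ForwardDiff.jl is installed (separate from ForwardDiff2)

# Same model as in the report, repeated so this block runs on its own.
function equations(vars, theta, n_exo, n_endo)
    XP = vars[1:n_exo]
    X  = vars[(n_exo + 1):(2 * n_exo)]
    YP = vars[(2 * n_exo + 1):(2 * n_exo + n_endo)]
    Y  = vars[(2 * n_exo + n_endo + 1):(2 * (n_exo + n_endo))]
    beta, alpha, delta, rho, sigma = theta
    return [1 / Y[1] - beta / YP[1] * (alpha * exp(XP[2]) * XP[1]^(alpha - 1) + 1 - delta),
            Y[1] + XP[1] - (1 - delta) * X[1] - exp(X[2]) * X[1]^alpha,
            XP[2] - rho * X[2]]
end

X_val = [1.0, 0.0, 1.0, 0.0, 1.0, 1.0]
θ_val = [0.99, 0.3, 0.1, 0.8, 0.01]

# Parameters fixed numerically; differentiate only with respect to the 6 variables.
f(x) = equations(x, θ_val, 2, 1)

# First derivatives: a 3 × 6 Jacobian (3 equations, 6 variables).
J = ForwardDiff.jacobian(f, X_val)

# Second derivatives: Jacobian of the flattened Jacobian, an 18 × 6 matrix (108 entries).
H = ForwardDiff.jacobian(x -> vec(ForwardDiff.jacobian(f, x)), X_val)

println(eltype(J), " ", length(J))   # Float64 18
println(eltype(H), " ", length(H))   # Float64 108
```

Both results are plain Float64 arrays, with the same entry counts as f_∇H (18) and f_∇²H (108) in the report.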
