Skip to content

Enzyme failing in Julia: ERROR: LoadError: Enzyme execution failed. #1571

@miguelborrero5

Description

@miguelborrero5

Hi there,

Opening this issue to see whether I could get some help making reverse mode in `autodiff` work. I got it working on trivial examples, even ones containing conditional statements, but for some reason it won’t work on the function I am interested in. This function takes various parameters and a vector of floats with respect to which we want to calculate the gradient. It is a long function, so for now I will only provide a reduced example; I am sorry if this is still too long. Below you can find self-contained code that ends with a function calling `autodiff` from Enzyme, followed by the error message.

using Distributions 
using Polynomials
using Combinatorics
using Statistics
using FastGaussQuadrature
using Interpolations
using LinearAlgebra
using Printf
using BenchmarkTools
using StaticArrays
using Enzyme

"""
    SpecFormat

Model specification consumed by `get_param_indices`.

- `special_banks`: banks modelled with their own coefficients; `get_param_indices` creates
  `length(special_banks) + 1` bank coefficient blocks.
- `home_bank`, `max_banks_in_choice_set`: not read by the code in this excerpt.
- `hermite_order`: number of coefficients per bank / jump-bid Hermite block
  (`hermite5_cdf`/`hermite5_pdf` read `a[1:5]`).
- `bernstein_order`: number of coefficients of the implied jump-bid Bernstein map
  (`bernstein3` reads `a[1:4]`).
- `*_X`, `*_shifters`: covariate names (Symbol or String) for each model component.
  Initial and final mean (resp. std) shifter lists must have equal lengths
  (asserted in `get_param_indices`).
"""
struct SpecFormat
    special_banks::Vector{Int64} # list of banks that are special
    home_bank::Int64
    max_banks_in_choice_set::Int64 # max number of banks in the choice set 

    hermite_order::Int64   # coefficients per Hermite block
    bernstein_order::Int64 # coefficients of the Bernstein map

    initial_search_cost_X::Vector{Union{Symbol, String}}
    per_bank_search_cost_X::Vector{Union{Symbol, String}} # individual-bank level covariates
    ρ_X::Vector{Union{Symbol, String}}
    γ_X::Vector{Union{Symbol, String}}
    γ_std_X::Vector{Union{Symbol, String}}
    κ_X::Vector{Union{Symbol, String}}
    initial_mean_shifters::Vector{Union{Symbol, String}}
    initial_std_shifters::Vector{Union{Symbol, String}}
    final_mean_shifters::Vector{Union{Symbol, String}}
    final_std_shifters::Vector{Union{Symbol, String}}

end

"""
    SubParamIndices

1-based positions into the per-observation parameter vector read by `get_ll_single`
(e.g. `params[param_indices.κ]`). `get_param_indices` assigns consecutive positions in
field order: the eleven scalar fields first, then `per_bank_search_cost`, the per-bank
Hermite blocks `bank_params`, `jump_bid_params` and `implied_jump_bid_params`.
`n_params` is the total number of positions assigned.
"""
struct SubParamIndices
    κ::Int64
    γ::Int64
    γ_std::Int64
    initial_search_cost::Int64
    ρ::Int64
    initial_bank_μ::Int64
    initial_bank_σ::Int64
    final_bank_μ::Int64
    final_bank_σ::Int64
    σinv_choice::Int64
    σinv_start_search::Int64

    per_bank_search_cost::Vector{Int64}     # one entry per bank block
    bank_params::Vector{Vector{Int64}}      # one Hermite coefficient block per bank
    jump_bid_params::Vector{Int64}          # Hermite coefficients of the jump bid
    implied_jump_bid_params::Vector{Int64}  # Bernstein coefficients of the implied jump bid

    n_params::Int64 # total length of the parameter vector

end

"""
    ParamIndices

Full parameter layout built by `get_param_indices`.

Fields without the `_indices` suffix (plus `σinv_choice`, `σinv_start_search`) are 1-based
positions in the full parameter vector. Fields ending in `_indices` are positions of each
component's covariates within the sorted union of covariate names (`all_X` in
`get_param_indices`). `n_params` is the total number of parameter positions.
"""
struct ParamIndices
    bank_params::Vector{Vector{Int64}}
    jump_bid_params::Vector{Int64}
    implied_jump_bid_params::Vector{Int64}
    initial_search_cost_X::Vector{Int64}
    per_bank_search_cost_X::Vector{Int64}
    ρ_X::Vector{Int64}
    γ_X::Vector{Int64}
    γ_std_X::Vector{Int64}
    κ_X::Vector{Int64}
    bank_mean_shifters::Vector{Int64}
    bank_std_shifters::Vector{Int64}
    # Positions within the sorted covariate-name union:
    initial_search_cost_X_indices::Vector{Int64}
    ρ_X_indices::Vector{Int64}
    γ_X_indices::Vector{Int64}
    γ_std_X_indices::Vector{Int64}
    κ_X_indices::Vector{Int64}
    initial_mean_indices::Vector{Int64}
    initial_std_indices::Vector{Int64}
    final_mean_indices::Vector{Int64}
    final_std_indices::Vector{Int64}
    σinv_choice::Int64
    σinv_start_search::Int64
    n_params::Int64
end

"""
    cumtrapz_upper(x, y) -> Vector

Upper cumulative trapezoidal integral: element `i` approximates the integral of `y` over
`[x[i], x[end]]` using the trapezoid rule on the grid `x`. The last element is therefore
zero. An empty grid returns an empty vector.

Throws `DimensionMismatch` if `x` and `y` have different lengths.
"""
function cumtrapz_upper(x::Vector{Float64}, y::AbstractArray{TV, 1})::Vector{TV} where TV
    n = length(x)
    length(y) == n ||
        throw(DimensionMismatch("x has length $n but y has length $(length(y))"))

    # Allocate with the element type of `y` so non-Float64 element types need no conversion.
    z = zeros(TV, n)
    n == 0 && return z

    # Lower cumulative sums of (Δx * (y_i + y_{i+1})); the factor 1/2 is applied below.
    # No @fastmath: inputs may contain NaN, which fastmath is allowed to assume away.
    @inbounds for i in 1:(n - 1)
        z[i + 1] = z[i] + (x[i + 1] - x[i]) * (y[i + 1] + y[i])
    end

    # Convert to upper integrals in place: 0.5 * (total - lower_i).
    total = z[end]
    @inbounds for i in eachindex(z)
        z[i] = 0.5 * (total - z[i])
    end
    return z
end

"""
    hermite5_cdf(a, μ, σ)

Return a closure `func(x)` that evaluates, elementwise over `x`,

    σ * P((x - μ) / σ) * pdf(Normal(μ, σ), x) / 24 + cdf(Normal(μ, σ), x)

where `P = Polynomial(poly_coeffs)` is the degree-7 polynomial whose eight coefficients
(constant term first) are built from `a[1:5]` below.

The closure captures `a`'s derived coefficients and `μ`, `σ`. `Normal(μ, σ)` is constructed
on every call, not when the closure is created.

NOTE(review): presumably this is the CDF of the density returned by `hermite5_pdf` for the
same `(a, μ, σ)`. The coefficient algebra is not verified here; confirm it against the
derivation.
"""
function hermite5_cdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    # Coefficients of P, lowest degree first (index k ↔ power k-1).
    poly_coeffs::Vector{TV} = [-48 * a[1] * a[2] - 24 * sqrt(2) * a[2] * a[3] + 8 * sqrt(6) * a[1] * a[4] - 24 * sqrt(3) * a[3] * a[4] + 4 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5], 
        -24 * a[2]^2 - 24 * sqrt(2) * a[1] * a[3] - 12 * a[3]^2 - 24 * a[4]^2 + 12 * sqrt(6) * a[1] * a[5] - 12 * sqrt(3) * a[3] * a[5] - 15 * a[5]^2, 
        -24 * sqrt(2) * a[2] * a[3] - 8 * sqrt(6) * a[1] * a[4] + 8 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5],
        -12 * a[3]^2 - 8 * sqrt(6) * a[2] * a[4] + 4 * a[4]^2 - 4 * sqrt(6) * a[1] * a[5] + 8 * sqrt(3) * a[3] * a[5] - 17 * a[5]^2,
        -8 * sqrt(3) * a[3] * a[4] - 4 * sqrt(6) * a[2] * a[5] + 12 * a[4] * a[5],
        -4 * a[4]^2 - 4 * sqrt(3) * a[3] * a[5] + 5 * a[5]^2,
        -4 * a[4] * a[5],
        -a[5]^2]
    # Normal CDF plus a polynomial-times-density correction, evaluated elementwise.
    func(x) = σ * Polynomial(poly_coeffs).((x .- μ) / σ) .* pdf.(Normal(μ, σ), x) ./ 24.0 .+ cdf.(Normal(μ, σ), x)
    return func
end

"""
    hermite5_pdf(a, μ, σ)

Return a closure `func(x)` that evaluates, elementwise over `x`,

    P((x - μ) / σ)^2 * pdf(Normal(μ, σ), x)

where `P = Polynomial(poly_coeffs)` is the degree-4 polynomial whose coefficients
(constant term first) are built from `a[1:5]` below. The result is non-negative because
the polynomial factor is squared.

NOTE(review): whether this integrates to one depends on the normalisation of `a`; confirm.
"""
function hermite5_pdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    # Coefficients of P, lowest degree first.
    poly_coeffs::Vector{TV} = [a[1] - a[3] / sqrt(2) + sqrt(3/8) * a[5], 
        a[2] - sqrt(3/2) * a[4], 
        a[3] / sqrt(2) - sqrt(3/2) * a[5], 
        a[4] / sqrt(6), 
        a[5] / sqrt(24)]
    # Squared polynomial in the standardised variable times the normal density.
    func(x) = (Polynomial(poly_coeffs).((x .- μ) / σ)).^2 .* pdf.(Normal(μ, σ), x)
    return func
end

"""
    bernstein3(a; upper = 1.0, lower = 0.0)

Return a closure evaluating the cubic Bernstein polynomial with coefficients `a[1:4]`.
The argument is first rescaled from `[lower, upper]` to `[0, 1]`:

    t = (x - lower) / (upper - lower),  s = 1 - t
    f(x) = a[1] s^3 + 3 a[2] t s^2 + 3 a[3] t^2 s + a[4] t^3

The closure indexes `a` at call time, so later changes to `a` are visible to it.
"""
function bernstein3(a::AbstractArray{TV, 1}; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    width = upper - lower
    return function (x)
        t = (x - lower) / width
        s = 1 - t
        return a[1] * s^3 + 3 * a[2] * t * s^2 + 3 * a[3] * t^2 * s + a[4] * t^3
    end
end

# NOTE(review): a method with this same signature, `solve_cubic_eq(::AbstractVector{Complex{T}})`,
# is defined again later in this file (after the "# From PolynomialRoots.jl" comment).
# When the file is loaded top to bottom, Julia keeps only the later definition, so this
# copy is never called. It can be deleted.
#
# Returns the first root of poly[1] + poly[2] x + poly[3] x^2 + poly[4] x^3 whose imaginary
# part is below 1e-8 in magnitude and whose real part lies strictly in (0, 1), or NaN.
function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T
    # Cubic equation solver for complex polynomial (degree=3)
    # http://en.wikipedia.org/wiki/Cubic_function   Lagrange's method
    a1  =  1 / poly[4]
    E1  = -poly[3]*a1
    E2  =  poly[2]*a1
    E3  = -poly[1]*a1
    s0  =  E1
    E12 =  E1*E1
    A   =  2*E1*E12 - 9*E1*E2 + 27*E3 # = s1^3 + s2^3
    B   =  E12 - 3*E2                 # = s1 s2
    # quadratic equation: z^2 - Az + B^3=0  where roots are equal to s1^3 and s2^3
    Δ = sqrt(A*A - 4*B*B*B)
    if real(conj(A)*Δ)>=0 # scalar product to decide the sign yielding bigger magnitude
        s1 = exp(log(0.5 * (A + Δ)) * (1/3))
    else
        s1 = exp(log(0.5 * (A - Δ)) * (1/3))
    end
    if s1 == 0
        s2 = s1
    else
        s2 = B / s1
    end
    zeta1 = complex(-0.5, sqrt(T(3.0))*0.5)
    zeta2 = conj(zeta1)
    # return third*(s0 + s1 + s2), third*(s0 + s1*zeta2 + s2*zeta1), third*(s0 + s1*zeta1 + s2*zeta2)

    sol1 = (1/3) * (s0 + s1 + s2)
    sol2 = (1/3) * (s0 + s1 * zeta2 + s2 * zeta1)
    sol3 = (1/3) * (s0 + s1 * zeta1 + s2 * zeta2)

    # Keep only (numerically) real roots strictly inside the unit interval.
    if abs(imag(sol1)) < 1e-8 && real(sol1) > 0.0 && real(sol1) < 1.0
        return real(sol1)
    elseif abs(imag(sol2)) < 1e-8 && real(sol2) > 0.0 && real(sol2) < 1.0
        return real(sol2)
    elseif abs(imag(sol3)) < 1e-8 && real(sol3) > 0.0 && real(sol3) < 1.0
        return real(sol3)
    else
        return NaN
    end
end

# NOTE(review): `bernstein3_inv` with this same signature is defined again later in this
# file. When the file is loaded, Julia keeps only the later definition, so this copy is
# never called and can be deleted.
#
# Inverts the cubic Bernstein map built by `bernstein3(p; upper, lower)`: returns `lower`
# when p[1] >= val, `upper` when p[4] <= val, and otherwise the root in (0, 1) found by
# `solve_cubic_eq`, rescaled to [lower, upper]. The clamping presumably assumes the map is
# increasing; confirm.
function bernstein3_inv(p::AbstractArray{TV, 1}, val::Float64; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    # Get this in the form ax^3 + bx^2 + cx + d = 0
    a = p[4] + 3 * p[2] - 3 * p[3] - p[1]
    b = 3 * p[1] - 6 * p[2] + 3 * p[3]
    c = 3 * p[2] - 3 * p[1]
    d = p[1] - val
    coeffs = [d, c, b, a]

    if d >= 0.0 # (poly(0.0))
        return lower
    elseif a + b + c + d <= 0.0 # (poly(1.0))
        return upper
    end

    x = solve_cubic_eq(Complex.(coeffs))
    return x * (upper - lower) + lower
end

# From PolynomialRoots.jl
"""
    solve_cubic_eq(poly; tol = 1e-8)

Return the first root of `poly[1] + poly[2] x + poly[3] x^2 + poly[4] x^3` that is
numerically real (`abs(imag(root)) < tol`) and lies strictly inside `(0, 1)`. Return `NaN`
if there is no such root.

The cubic case uses Lagrange's method (http://en.wikipedia.org/wiki/Cubic_function).
If the leading coefficient is exactly zero, the equation is solved as a quadratic or linear
equation instead of dividing by zero. This happens, for example, for the identity
Bernstein map `[0, 1/3, 2/3, 1]` in `bernstein3_inv`. If every non-constant coefficient is
zero, `NaN` is returned.
"""
function solve_cubic_eq(poly::AbstractVector{Complex{T}}; tol::Real = 1e-8) where T
    # Acceptance test shared by all branches.
    is_unit_root(z) = abs(imag(z)) < tol && real(z) > 0.0 && real(z) < 1.0

    if iszero(poly[4])
        # Degenerate cubic: fall back to the quadratic / linear formula.
        if !iszero(poly[3])
            disc = sqrt(poly[2]^2 - 4 * poly[3] * poly[1])
            candidates = ((-poly[2] + disc) / (2 * poly[3]), (-poly[2] - disc) / (2 * poly[3]))
        elseif !iszero(poly[2])
            candidates = (-poly[1] / poly[2],)
        else
            return NaN
        end
        for root in candidates
            is_unit_root(root) && return real(root)
        end
        return NaN
    end

    a1  =  1 / poly[4]
    E1  = -poly[3]*a1
    E2  =  poly[2]*a1
    E3  = -poly[1]*a1
    s0  =  E1
    E12 =  E1*E1
    A   =  2*E1*E12 - 9*E1*E2 + 27*E3 # = s1^3 + s2^3
    B   =  E12 - 3*E2                 # = s1 s2
    # quadratic equation: z^2 - Az + B^3=0  where roots are equal to s1^3 and s2^3
    Δ = sqrt(A*A - 4*B*B*B)
    if real(conj(A)*Δ)>=0 # scalar product to decide the sign yielding bigger magnitude
        s1 = exp(log(0.5 * (A + Δ)) * (1/3))
    else
        s1 = exp(log(0.5 * (A - Δ)) * (1/3))
    end
    s2 = s1 == 0 ? s1 : B / s1
    zeta1 = complex(-0.5, sqrt(T(3.0))*0.5)
    zeta2 = conj(zeta1)

    sol1 = (1/3) * (s0 + s1 + s2)
    sol2 = (1/3) * (s0 + s1 * zeta2 + s2 * zeta1)
    sol3 = (1/3) * (s0 + s1 * zeta1 + s2 * zeta2)

    # Same priority order as before: sol1, then sol2, then sol3.
    for sol in (sol1, sol2, sol3)
        is_unit_root(sol) && return real(sol)
    end
    return NaN
end

"""
    bernstein3_inv(p, val; upper = 1.0, lower = 0.0)

Invert the cubic Bernstein map `bernstein3(p; upper, lower)` at `val`.

`B(t) - val` is written in the power basis as `c3 t^3 + c2 t^2 + c1 t + c0`.
- If `B(0) - val >= 0`, return `lower`.
- If `B(1) - val <= 0`, return `upper`.
- Otherwise, the root in `(0, 1)` returned by `solve_cubic_eq` is rescaled to
  `[lower, upper]`. That root may be `NaN` if none is found.

The clamping presumably assumes the map is increasing; confirm.
"""
function bernstein3_inv(p::AbstractArray{TV, 1}, val::Float64; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    # Power-basis coefficients of B(t) - val.
    c3 = p[4] + 3 * p[2] - 3 * p[3] - p[1]
    c2 = 3 * p[1] - 6 * p[2] + 3 * p[3]
    c1 = 3 * p[2] - 3 * p[1]
    c0 = p[1] - val

    # Clamp when val lies outside the range of the map on [0, 1].
    c0 >= 0.0 && return lower                    # B(0) >= val
    c3 + c2 + c1 + c0 <= 0.0 && return upper     # B(1) <= val

    t = solve_cubic_eq(Complex.([c0, c1, c2, c3]))
    return t * (upper - lower) + lower
end

"""
    logit_probabilities!(inside_probabilities, outside_probabilities, utilities, N, σinv)

Fill logit choice probabilities for every `(uJ_index, bias_index)` cell of `utilities`.

Each choice set `j` (first axis) has scaled utility `v_j = utilities[j, uJ, bias] * σinv`
and is available in `N[j]` identical copies. The outside option has utility 0. For each
cell:

    inside_probabilities[j, uJ, bias] = exp(v_j) / (1 + Σ_k N[k] exp(v_k))
    outside_probabilities[uJ, bias]   = 1 / (1 + Σ_k N[k] exp(v_k))

Before exponentiating, the scaled utilities are shifted by `max(0, maximum(v))`. Large
utilities therefore no longer overflow to `Inf` and produce `NaN` probabilities. When all
`v ≤ 0`, the shift is zero and the arithmetic is exactly the unshifted formula.

Returns `nothing`.
"""
function logit_probabilities!(inside_probabilities::Array{TV, 3}, outside_probabilities::Matrix{TV}, utilities::Array{TV, 3}, N::Vector{Int64}, σinv::TV) where TV

    # bias outer / uJ inner: column-major friendly; each cell is independent.
    for bias_index in axes(utilities, 3)
        for uJ_index in axes(utilities, 2)
            scaled_utilities = @views utilities[:, uJ_index, bias_index] .* σinv
            # Log-sum-exp style shift; `init` covers an empty choice-set axis and the outside option.
            shift = maximum(scaled_utilities; init = zero(TV))
            exp_utilities = exp.(scaled_utilities .- shift)
            outside_weight = exp(-shift)
            inv_denom = 1.0 / (outside_weight + dot(exp_utilities, N))
            inside_probabilities[:, uJ_index, bias_index] .= inv_denom .* exp_utilities
            outside_probabilities[uJ_index, bias_index] = outside_weight * inv_denom
        end
    end

    return nothing
end

"""
    partial_interpolate(x, y, idx, val)

Linearly interpolate the data `(x, y)` at `val`, given the bracketing index
`idx = searchsortedlast(x, val)`.

Outside the grid the value is clamped:
- `idx <= 0` (val below `x[1]`) returns `y[1]`;
- `idx >= length(x)` returns `y[end]`.

Previously these edge branches returned the grid coordinates `x[1]` / `x[end]` instead of
data values.
"""
function partial_interpolate(x::AbstractArray{Float64, 1}, y::AbstractArray{TV, 1}, idx::Int64, val::Float64) where TV
    if idx <= 0
        return convert(TV, y[1])
    elseif idx >= length(x)
        return convert(TV, y[end])
    end
    # Interior: linear interpolation between (x[idx], y[idx]) and (x[idx+1], y[idx+1]).
    return convert(TV, ((y[idx+1] - y[idx]) * (val - x[idx]) / (x[idx+1] - x[idx])) + y[idx])
end

"""
    get_spec()::SpecFormat

Return the hard-coded model specification used by the example. Each covariate list is a
fresh `Vector{Union{Symbol, String}}`.
"""
function get_spec()::SpecFormat
    # Build a fresh covariate-name vector with the element type SpecFormat expects.
    covariates(names...) = Union{Symbol, String}[names...]

    return SpecFormat(
        [1, 14, 16],                                                         # special_banks
        12,                                                                  # home_bank
        3,                                                                   # max_banks_in_choice_set
        5,                                                                   # hermite_order
        4,                                                                   # bernstein_order
        covariates(:ones, :x_educ_high, :x_income),                          # initial_search_cost_X
        covariates(:ones, :x_educ_high, :x_income, :s_n_branches),           # per_bank_search_cost_X
        covariates(:ones, :x_educ_high, :x_income),                          # ρ_X
        covariates(:ones, :x_educ_high, :x_income),                          # γ_X
        covariates(:ones),                                                   # γ_std_X
        covariates(:ones, :x_educ_high, :x_income),                          # κ_X
        covariates(:ones, :x_educ_high, :x_income, :b_term, :b_amount),      # initial_mean_shifters
        covariates(:ones, :b_term, :b_amount),                               # initial_std_shifters
        covariates(:ones, :x_educ_high, :x_income, :b_term, :b_amount),      # final_mean_shifters
        covariates(:ones, :b_term, :b_amount),                               # final_std_shifters
    )
end

"""
    get_indices(grid, vals) -> Vector{Int64}

For each value `v` in `vals` (in iteration order), return `searchsortedlast(grid, v)`.
This is the index of the last grid point `<= v`, or 0 when `v < grid[1]`. Assumes `grid`
is sorted ascending.
"""
function get_indices(grid::AbstractArray{Float64, 1}, vals::AbstractArray{Float64, 1})::Vector{Int64}
    positions = Int64[]
    sizehint!(positions, length(vals))
    for v in vals
        push!(positions, searchsortedlast(grid, v))
    end
    return positions
end

"""
    get_param_indices(spec::SpecFormat) -> (ParamIndices, SubParamIndices)

Lay out two parameter vectors by handing out consecutive 1-based positions with
`param_update`.

- `ParamIndices` (full layout):
  - a Hermite block for each of the `length(spec.special_banks) + 1` banks;
  - the jump-bid Hermite block and the implied jump-bid Bernstein block;
  - one coefficient per covariate for each search-cost/ρ/γ/γ_std/κ component;
  - shared mean/std shifter coefficients;
  - the two logit scale parameters.
  It also records where each component's covariates sit in the sorted union of covariate
  names (`all_X`).
- `SubParamIndices` (per-observation layout read by `get_ll_single`):
  - eleven scalars;
  - the per-bank search costs;
  - the bank, jump-bid and implied jump-bid coefficient blocks.

The order of the `param_update` calls defines the layout; reordering them changes which
entry of the parameter vector each field refers to.
"""
function get_param_indices(spec::SpecFormat)::Tuple{ParamIndices, SubParamIndices}
    param_val::Int64 = 0 # last position handed out so far

    bank_params::Vector{Vector{Int64}} = []
    for b = 1:(length(spec.special_banks)+1)
        this_bank_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
        push!(bank_params, this_bank_params)
    end

    jump_bid_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
    implied_jump_bid_params::Vector{Int64}, param_val = param_update(spec.bernstein_order, param_val)

    # One coefficient per covariate name for each component.
    initial_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.initial_search_cost_X), param_val)
    per_bank_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.per_bank_search_cost_X), param_val)
    ρ_X::Vector{Int64}, param_val = param_update(length(spec.ρ_X), param_val)
    γ_X::Vector{Int64}, param_val = param_update(length(spec.γ_X), param_val)
    γ_std_X::Vector{Int64}, param_val = param_update(length(spec.γ_std_X), param_val)
    κ_X::Vector{Int64}, param_val = param_update(length(spec.κ_X), param_val)

    # Initial and final shifters share one coefficient block, so their lengths must agree.
    @assert(length(spec.initial_mean_shifters) == length(spec.final_mean_shifters))
    @assert(length(spec.initial_std_shifters) == length(spec.final_std_shifters))
    bank_mean_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_mean_shifters), param_val)
    bank_std_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_std_shifters), param_val)
    σinv_choice::Int64, param_val = param_update(1, param_val; return_as_vector = false)
    σinv_start_search::Int64, param_val = param_update(1, param_val; return_as_vector = false)

    # Now get the indices: Will have one X matrix at the individual level (N_individuals x total parameters)
    # and then indices for which subsets of this X to use for various parameters
    # NOTE(review): `sort!` needs all names to be mutually comparable; mixing Symbol and
    # String entries would likely throw a MethodError. Confirm if String names are ever used.
    all_X = sort!(union(spec.initial_search_cost_X, spec.ρ_X, spec.γ_X, spec.γ_std_X, spec.κ_X, 
        spec.initial_mean_shifters, spec.initial_std_shifters, spec.final_mean_shifters, spec.final_std_shifters)) # per-bank search cost will be at the individual-bank level
    # Position of each covariate name inside all_X (every name is present by construction).
    initial_search_cost_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.initial_search_cost_X]
    ρ_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.ρ_X]
    γ_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.γ_X]
    γ_std_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.γ_std_X]
    κ_X_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.κ_X]
    initial_mean_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.initial_mean_shifters]
    initial_std_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.initial_std_shifters]
    final_mean_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.final_mean_shifters]
    final_std_indices::Vector{Int64} = [findall(x -> x == i, all_X)[1] for i in spec.final_std_shifters]

    param_indices = ParamIndices(bank_params, jump_bid_params, implied_jump_bid_params, 
        initial_search_cost_X, per_bank_search_cost_X, ρ_X, γ_X, γ_std_X, κ_X,
        bank_mean_shifters, bank_std_shifters,
        initial_search_cost_X_indices, ρ_X_indices, γ_X_indices, γ_std_X_indices, κ_X_indices, 
        initial_mean_indices, initial_std_indices, final_mean_indices, final_std_indices, 
        σinv_choice, σinv_start_search, param_val)

    # Per-observation layout: positions restart at 1 and follow SubParamIndices' field order.
    sub_param_val::Int64 = 0
    sub_κ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ_std::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_search_cost::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_ρ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_choice::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_start_search::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_per_bank_search_cost::Vector{Int64}, sub_param_val = param_update(length(spec.special_banks)+1, sub_param_val)
    sub_bank_params::Vector{Vector{Int64}} = []
    for b = 1:(length(spec.special_banks)+1)
        this_bank_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
        push!(sub_bank_params, this_bank_params)
    end
    sub_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
    sub_implied_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.bernstein_order, sub_param_val)

    sub_param_indices = SubParamIndices(sub_κ, sub_γ, sub_γ_std, sub_initial_search_cost, sub_ρ, 
        sub_initial_bank_μ, sub_initial_bank_σ, sub_final_bank_μ, sub_final_bank_σ, 
        sub_σinv_choice, sub_σinv_start_search, sub_per_bank_search_cost, sub_bank_params, 
        sub_jump_bid_params, sub_implied_jump_bid_params, sub_param_val)
        

    return (param_indices, sub_param_indices)
end

"""
    param_update(N, param_val; return_as_vector = true)

Reserve the next positions after `param_val` and return `(positions, new_param_val)`.

- Vector form (used when `return_as_vector` is true or `N > 1`): the positions are
  `collect(param_val+1:param_val+N)` and the counter advances by `N`.
- Scalar form (used when `N <= 1` and `return_as_vector` is false): a single `Int64`
  position `param_val + 1` is returned and the counter advances by exactly 1.
"""
function param_update(N::Int64, param_val::Int64; return_as_vector::Bool = true)
    if N <= 1 && !return_as_vector
        slot = param_val + 1
        return slot, slot
    end
    last_slot = param_val + N
    return collect((param_val + 1):last_slot), last_slot
end

"""
    get_ll_single(params, param_indices, uJ_grid, base_nodes, weights, u_current_idx,
                  remaining_term, amount, chosen_bank, current_monthly_payment,
                  final_monthly_payment, search_type, choice_sets, choice_sets_N,
                  chosen_choice_set)

Return the log-likelihood contribution `log(1 - ρ + ρ * dot(pr_no_search, weights))` of one
observation.

Steps:
1. Read the scalar parameters from `params` at the positions in `param_indices`.
2. Build the 9 bias nodes `bias_grid = base_nodes * γ_std * sqrt(2) .+ γ`.
3. For every bias node and choice set, compute the value of searching over that choice set
   on `uJ_grid`, minus the summed per-bank search costs.
4. Convert those values to logit probabilities (`logit_probabilities!`, weighted by
   `choice_sets_N`). Integrate the expected benefit against the jump-bid density up to
   `u_current = -current_monthly_payment` to obtain `S_base`.
5. Pass `S_base` through a logistic in `σinv_start_search * (S_base - initial_search_cost)`
   to get the search probability at each node.

`search_type`, `chosen_choice_set`, `u_final`, `implied_jump_bid_inv` and the `final_*`
closures are not used in this reduced version.

Fixes relative to the previous version:
- The `next_part` update was written as `cond ? next_part[i] = A : B`, which parses as
  `if cond; next_part[i] = A; else; B; end`. The `B` branch was computed and discarded,
  leaving `base_0[i]` in place. It is now assigned in both branches.
- Work buffers use element type `TV` instead of `Float64`.
- The search probability uses `1 / (1 + exp(-z))`, which does not overflow to `NaN` for
  large `z`.

NOTE(review): `partial_cumtrapz` is called below but is not defined in this excerpt;
confirm it is provided elsewhere.
"""
function get_ll_single(params::AbstractArray{TV, 1}, param_indices::SubParamIndices,
        uJ_grid::Vector{Float64}, base_nodes::SVector{9, Float64}, weights::SVector{9, Float64},
        u_current_idx::Int64, remaining_term::Float64, amount::Float64, chosen_bank::Int64,
        current_monthly_payment::Float64, final_monthly_payment::Float64,
        search_type::Int64, choice_sets::Vector{Vector{Int64}}, choice_sets_N::Vector{Int64},
        chosen_choice_set::Int64) where TV

    # Initialize the values of the frictions -- this applies for everyone
    κ::TV = params[param_indices.κ]
    γ::TV = params[param_indices.γ]
    γ_std::TV = params[param_indices.γ_std]
    initial_search_cost::TV = params[param_indices.initial_search_cost]
    ρ::TV = params[param_indices.ρ]
    initial_bank_μ::TV = params[param_indices.initial_bank_μ]
    initial_bank_σ::TV = params[param_indices.initial_bank_σ]
    final_bank_μ::TV = params[param_indices.final_bank_μ]
    final_bank_σ::TV = params[param_indices.final_bank_σ]
    σinv_choice::TV = params[param_indices.σinv_choice]
    σinv_start_search::TV = params[param_indices.σinv_start_search]

    ## Now the loop begins
    # Note that the cdf of zero-profit rates is written in terms of monthly payments if A = 1, scaled by number of
    # payments. So, a value of m corresponds to a monthly payment of m * A / T.
    #
    # We know that monthly utility = (delta - m * A / T) - kappa / NPVrate(T),
    # So, Pr(utility <= u) = Pr((delta - m * A / T) - kappa/NPV <= u) = Pr(scaled monthly payment >= (T/A) * (delta - (u + kappa/NPV)))
    # The bias is in terms of monthly payments. So, we just add the bias.
    β_customer_m::Float64 = 0.95^(1/12)
    per_bank_search_costs::Vector{TV} = params[param_indices.per_bank_search_cost]
    npv_scale::Float64 = (1.0 - β_customer_m^remaining_term) / (1 - β_customer_m)
    overall_scale::Float64 = remaining_term / amount
    bias_grid = base_nodes * γ_std * sqrt(2) .+ γ # in units of dollars
    scaled_κ::TV = κ / npv_scale # in units of dollars per month


    # Generate the cdfs and pdfs for the banks: cdfs at the time of search, and both later if needed

    # Maps utilities ($) to utilities ($) <--- could rethink whether this is the best map
    implied_jump_bid::Function = bernstein3(params[param_indices.implied_jump_bid_params]; upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)
    implied_jump_bid_inv(val::Float64) = bernstein3_inv(params[param_indices.implied_jump_bid_params], val; upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)


    initial_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], initial_bank_μ, initial_bank_σ)(-overall_scale .* (x .+ scaled_κ)) for a in param_indices.bank_params]
    initial_jump_bid_cdf::Function = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_jump_bid_pdf::Function = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_home_bank_cdf(x) = initial_jump_bid_cdf(implied_jump_bid.(x))

    final_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], final_bank_μ, final_bank_σ)(@. -overall_scale * (x + scaled_κ)) for a in param_indices.bank_params]
    final_bank_pdf = [x -> overall_scale .* hermite5_pdf(params[a], final_bank_μ, final_bank_σ)(@. -overall_scale * (x + scaled_κ)) for a in param_indices.bank_params]
    final_jump_bid_cdf::Function = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], final_bank_μ, final_bank_σ)(@. -overall_scale * x)
    final_jump_bid_pdf::Function = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], final_bank_μ, final_bank_σ)(@. -overall_scale * x)
    final_home_bank_cdf(x) = final_jump_bid_cdf(implied_jump_bid.(x))


    refi_at_home_bank::Bool = chosen_bank == 0
    u_current::Float64 = -current_monthly_payment
    u_final::TV = -final_monthly_payment - scaled_κ * refi_at_home_bank

    n_grid = length(uJ_grid)
    n_banks = length(initial_bank_cdf)

    # Value of each choice set on the uJ grid, per bias node.
    choice_set_utilities::Array{TV, 3} = zeros(TV, length(choice_sets), n_grid, length(bias_grid))

    # Reused work vectors over the uJ grid.
    base_0::Vector{TV} = zeros(TV, n_grid)
    base_1::Vector{TV} = zeros(TV, n_grid)
    to_add::Vector{TV} = zeros(TV, n_grid)

    # Element type TV (not Float64) so these buffers can hold any parameter element type.
    initial_bank_cdf_vals::Matrix{TV} = zeros(TV, n_banks, n_grid)
    one_minus_initial_bank_cdf_vals::Matrix{TV} = zeros(TV, n_banks, n_grid)


    @inbounds for (bias_index, bias) in enumerate(bias_grid)

        biased_uJ = uJ_grid .+ bias

        for b = 1:n_banks
            initial_bank_cdf_vals[b, :] .= initial_bank_cdf[b](biased_uJ)
            one_minus_initial_bank_cdf_vals[b, :] .= @views 1.0 .- initial_bank_cdf_vals[b, :]
        end

        initial_home_bank_cdf_vals::Vector{TV} = initial_home_bank_cdf(biased_uJ)
        good_indices::BitVector = (initial_home_bank_cdf_vals[end] .- initial_home_bank_cdf_vals) .>= 1e-7

        @inbounds for (choice_set_index, choice_set) in enumerate(choice_sets)

            # Product of the CDFs of all banks in the choice set.
            base_0 .= @views initial_bank_cdf_vals[choice_set[1], :]
            for (b_idx, b) in enumerate(choice_set)
                b_idx > 1 || continue
                base_0 .*= @views initial_bank_cdf_vals[b, :]
            end
            cumtrapz_0 = cumtrapz_upper(uJ_grid, base_0)

            # Sum over b of (1 - F_b) * prod_{b' != b} F_{b'}.
            base_1 .= 0.0
            for (b_index, b) in enumerate(choice_set)
                to_add .= @views one_minus_initial_bank_cdf_vals[b, :]
                for (bprime_index, bprime) in enumerate(choice_set)
                    if bprime_index != b_index
                        to_add .*= @views initial_bank_cdf_vals[bprime, :]
                    end
                end
                base_1 .+= to_add
            end
            cumtrapz_1 = cumtrapz_upper(uJ_grid, base_1)

            base_1 .*= initial_home_bank_cdf_vals
            cumtrapz_1_with_home = cumtrapz_upper(uJ_grid, base_1)

            # base_0 is no longer needed in this iteration; reuse its storage.
            next_part = base_0
            for i in eachindex(next_part)
                tail = uJ_grid[end] - uJ_grid[i] - cumtrapz_0[i]
                # Assign in BOTH branches (the old ternary only assigned in the first).
                if good_indices[i]
                    home_gap = initial_home_bank_cdf_vals[end] - initial_home_bank_cdf_vals[i]
                    next_part[i] = tail - (cumtrapz_1_with_home[i] - initial_home_bank_cdf_vals[i] * cumtrapz_1[i]) / home_gap
                else
                    next_part[i] = tail - cumtrapz_1[i]
                end
            end


            this_search_cost::TV = sum(per_bank_search_costs[b] for b in choice_set)
            choice_set_utilities[choice_set_index, :, bias_index] .= @. next_part - this_search_cost

        end
    end


    choice_set_probabilities::Array{TV, 3} = zeros(TV, size(choice_set_utilities))
    outside_option_probabilities::Array{TV, 2} = zeros(TV, n_grid, length(bias_grid))
    logit_probabilities!(choice_set_probabilities, outside_option_probabilities,
        choice_set_utilities, choice_sets_N, σinv_choice)

    benefit_of_search_uJ::Array{TV, 2} = dropdims(sum((choice_sets_N .* choice_set_probabilities) .* choice_set_utilities, dims = 1), dims = 1)
    S_base::Vector{TV} = zeros(TV, length(bias_grid))

    for (bias_index, bias) in enumerate(bias_grid)
        initial_jump_bid_pdf_vals = initial_jump_bid_pdf(uJ_grid .+ bias)

        S_base[bias_index] = @views partial_cumtrapz(uJ_grid, benefit_of_search_uJ[1:(u_current_idx+1), bias_index] .* initial_jump_bid_pdf_vals[1:(u_current_idx+1)], u_current_idx, u_current) + initial_jump_bid_cdf(u_current + bias) *
            partial_interpolate(uJ_grid, benefit_of_search_uJ[:, bias_index], u_current_idx, u_current)

    end

    # Logistic search probability, one entry per bias node (length(bias_grid) = 9).
    # Written as 1 / (1 + exp(-z)) so large z gives 1 instead of Inf/Inf = NaN.
    pr_search::Vector{TV} = @. 1.0 / (1.0 + exp(-σinv_start_search * (S_base - initial_search_cost)))
    pr_no_search = 1 .- pr_search


    # Mixture: with prob. 1 - ρ the observation is explained without search; otherwise
    # average the no-search probability over the bias nodes with quadrature weights.
    likelihood = 1 - ρ + ρ * dot(pr_no_search, weights)


    return log(likelihood)

end

"""
    call_autodiff_example()

Build one hard-coded observation and run reverse-mode Enzyme on `get_ll_single` with
respect to `params`. Enzyme accumulates the gradient into the shadow vector `dx`. The
function returns whatever `autodiff` returns.

Throws `DimensionMismatch` if `params` does not match the layout from `get_param_indices`.
"""
function call_autodiff_example()

    spec::SpecFormat = get_spec()
    sub_param_indices::SubParamIndices = get_param_indices(spec)[2]

    # Negated and sorted ascending, so the grid runs from -6.0 up to 0.
    uJ_grid::Vector{Float64} = union(collect(0:0.05:2.0), [2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0]) # this will be a constant
    uJ_grid = sort!(-uJ_grid)

    # 9-point Gauss–Hermite rule; dividing by sqrt(π) makes the weights sum to one.
    base_nodes, base_weights = SVector{9}.(FastGaussQuadrature.gausshermite(9))
    weights = SVector{9}(base_weights / sqrt(π))

    params = [0.0, 0.0, 0.25, 0.0, 0.1, 1.2, 0.25, 1.2, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0,
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0]
    # The parameter vector must match the per-observation layout.
    length(params) == sub_param_indices.n_params ||
        throw(DimensionMismatch("expected $(sub_param_indices.n_params) parameters, got $(length(params))"))

    u_current_idx = 44
    remaining_term = 67.0
    amount = 13.32
    chosen_bank = -1
    current_monthly_payment = 0.24
    final_monthly_payment = NaN
    search_type = 1
    choice_sets = [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3], [1, 3, 4],
                    [1, 4], [1, 4, 4], [2], [2, 3], [2, 3, 4], [2, 4], [2, 4, 4],
                    [3], [3, 4], [3, 4, 4], [4], [4, 4], [4, 4, 4]]
    choice_sets_N = [1, 1, 1, 15, 1, 15, 15, 105, 1, 1, 15, 15, 105, 1, 15, 105, 15, 105, 455]
    chosen_choice_set = 0
    # Shadow of `params`: must have the same shape, otherwise Enzyme writes out of bounds.
    # zero(params) keeps it in sync instead of the previous hard-coded zeros(44).
    dx = zero(params)

    single_ll(x) = get_ll_single(x, sub_param_indices, uJ_grid, base_nodes, weights, u_current_idx,
        remaining_term, amount, chosen_bank, current_monthly_payment, final_monthly_payment, search_type, choice_sets, choice_sets_N, chosen_choice_set)
    return autodiff(Reverse, single_ll, Active, Duplicated(params, dx))

end

# Run the example when this file is executed.
call_autodiff_example()

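When I run this, I get the following error and stack trace. Note that the file/line references such as `example_Enzyme.jl:527` point into my full local script, so they may not line up with the reduced code above: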

ERROR: LoadError: Enzyme execution failed.
Mismatched activity for:   store {} addrspace(10)* %value_phi5253395, {} addrspace(10)** %.fca.1.gep, align 8, !dbg !1671, !noalias !216 const val:   %value_phi5253395 = phi {} addrspace(10)* [ %arrayref514, %L1631.lr.ph ], [ %arrayref1084, %L3207 ]
Type tree: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}
 llvalue=  %arrayref514 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %arrayptr5112966, align 8, !dbg !1096, !tbaa !1097, !alias.scope !197, !noalias !198
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] get_ll_single
   @ ~/Documents/refinancing/multistep/example_Enzyme.jl:527

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:1612
  [2] get_ll_single
    @ ~/Documents/refinancing/multistep/example_Enzyme.jl:527
  [3] *
    @ ./float.jl:411 [inlined]
  [4] trapz
    @ ~/Documents/refinancing/multistep/example_Enzyme.jl:87
  [5] macro expansion
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6587 [inlined]
  [6] enzyme_call
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6188 [inlined]
  [7] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6065 [inlined]
  [8] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:309 [inlined]
  [9] autodiff
    @ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
 [10] call_autofiff_example()
    @ Main ~/Documents/refinancing/multistep/example_Enzyme.jl:592
 [11] top-level scope
    @ ~/Documents/refinancing/multistep/example_Enzyme.jl:596
 [12] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [13] top-level scope
    @ REPL[2]:1
in expression starting at /Users/miguelborrero/Documents/refinancing/multistep/example_Enzyme.jl:596

Any advice would be greatly appreciated! Thanks a lot in advance!

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions