-
Notifications
You must be signed in to change notification settings - Fork 82
Description
Hi there,
Opening this issue to see whether I could get some help making reverse-mode `autodiff` work. I got it working on trivial examples, even ones containing conditional statements, but for some reason it won’t work on the function I am interested in. This function takes in various parameters and then a vector of floats with respect to which we want to calculate the gradient. It is a long function, so for now I will only provide a reduced example; I am sorry if this is still too long. Below you can find self-contained code, which includes a function at the end that calls `autodiff` using Enzyme. After the code, I post the error message.
using Distributions
using Polynomials
using Combinatorics
using Statistics
using FastGaussQuadrature
using Interpolations
using LinearAlgebra
using Printf
using BenchmarkTools
using StaticArrays
using Enzyme
# Covariate-name list: each entry may be a Symbol or a String (e.g. `:ones`, `"x_income"`).
const SpecVarList = Vector{Union{Symbol, String}}

"""
    SpecFormat

Model specification: the banks treated individually ("special" banks), the home bank,
the maximum choice-set size, the Hermite and Bernstein expansion orders, and the names of
the covariates entering each model component.
"""
struct SpecFormat
    special_banks::Vector{Int64}       # list of banks that are special
    home_bank::Int64
    max_banks_in_choice_set::Int64     # max number of banks in the choice set
    hermite_order::Int64
    bernstein_order::Int64
    # Covariates of the search-friction components.
    initial_search_cost_X::SpecVarList
    per_bank_search_cost_X::SpecVarList
    ρ_X::SpecVarList
    γ_X::SpecVarList
    γ_std_X::SpecVarList
    κ_X::SpecVarList
    # Mean / std shifter covariates, initial and final stage.
    initial_mean_shifters::SpecVarList
    initial_std_shifters::SpecVarList
    final_mean_shifters::SpecVarList
    final_std_shifters::SpecVarList
end
# Block of consecutive parameter positions.
const SubIndexVector = Vector{Int64}

"""
    SubParamIndices

Positions of each component inside the per-observation ("sub") parameter vector: eleven
scalar parameters, then blocks for per-bank search costs, per-bank expansion
coefficients and the jump-bid coefficients, plus the total parameter count.
"""
struct SubParamIndices
    # Scalar parameters: one position each.
    κ::Int64
    γ::Int64
    γ_std::Int64
    initial_search_cost::Int64
    ρ::Int64
    initial_bank_μ::Int64
    initial_bank_σ::Int64
    final_bank_μ::Int64
    final_bank_σ::Int64
    σinv_choice::Int64
    σinv_start_search::Int64
    # Blocks of positions.
    per_bank_search_cost::SubIndexVector
    bank_params::Vector{SubIndexVector}
    jump_bid_params::SubIndexVector
    implied_jump_bid_params::SubIndexVector
    # Total number of entries in the sub-parameter vector.
    n_params::Int64
end
# List of parameter positions or covariate-column positions.
const ParamIndexVector = Vector{Int64}

"""
    ParamIndices

Layout of the full parameter vector: index blocks for every coefficient group, the
column positions of each covariate list inside the shared individual-level covariate
matrix, the two logit-scale positions and the total parameter count.
"""
struct ParamIndices
    # Coefficient blocks in the parameter vector.
    bank_params::Vector{ParamIndexVector}
    jump_bid_params::ParamIndexVector
    implied_jump_bid_params::ParamIndexVector
    initial_search_cost_X::ParamIndexVector
    per_bank_search_cost_X::ParamIndexVector
    ρ_X::ParamIndexVector
    γ_X::ParamIndexVector
    γ_std_X::ParamIndexVector
    κ_X::ParamIndexVector
    bank_mean_shifters::ParamIndexVector
    bank_std_shifters::ParamIndexVector
    # Column positions of each covariate list in the covariate matrix.
    initial_search_cost_X_indices::ParamIndexVector
    ρ_X_indices::ParamIndexVector
    γ_X_indices::ParamIndexVector
    γ_std_X_indices::ParamIndexVector
    κ_X_indices::ParamIndexVector
    initial_mean_indices::ParamIndexVector
    initial_std_indices::ParamIndexVector
    final_mean_indices::ParamIndexVector
    final_std_indices::ParamIndexVector
    # Scalar positions and the total count.
    σinv_choice::Int64
    σinv_start_search::Int64
    n_params::Int64
end
"""
    cumtrapz_upper(x, y) -> Vector

Upper-tail cumulative trapezoidal integral of `y` over the grid `x`: element `i` of the
result is the trapezoid-rule approximation of the integral of `y` from `x[i]` to
`x[end]`, so the last element is zero.

Throws `DimensionMismatch` if `x` and `y` have different lengths (the loop runs under
`@inbounds`, so a shorter `y` must be rejected up front). An empty grid gives an empty
result.
"""
function cumtrapz_upper(x::Vector{Float64}, y::AbstractArray{TV, 1})::Vector{TV} where TV
    n = length(x)
    length(y) == n ||
        throw(DimensionMismatch("x has length $n but y has length $(length(y))"))
    z = zeros(TV, n)
    n == 0 && return z
    # Running (doubled) lower cumulative integral: z[i] = 2 * ∫_{x[1]}^{x[i]} y.
    @inbounds @fastmath for i in 1:n-1
        z[i+1] = z[i] + (x[i+1] - x[i]) * (y[i+1] + y[i])
    end
    # Convert to the upper tail and undo the doubling in place.
    total = z[n]
    @inbounds for i in 1:n
        z[i] = 0.5 * (total - z[i])
    end
    return z
end
"""
    hermite5_cdf(a, μ, σ) -> func

Return a closure `func(x)` that evaluates, broadcasting over `x`,

    σ * P((x - μ) / σ) * pdf(Normal(μ, σ), x) / 24 + cdf(Normal(μ, σ), x)

where `P` is the polynomial `Polynomial(poly_coeffs)`. By Polynomials.jl convention the
first coefficient is the constant term, so the eight entries give a degree-7 polynomial
built from quadratic combinations of `a[1:5]`.

NOTE(review): by its name this is presumably the CDF matching `hermite5_pdf(a, μ, σ)`.
It is only a proper CDF if that density integrates to one for the given `a`. TODO
confirm the coefficient formulas and the normalisation of `a`.
"""
function hermite5_cdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
# Coefficients of P (constant term first); each is a quadratic form in a[1:5].
poly_coeffs::Vector{TV} = [-48 * a[1] * a[2] - 24 * sqrt(2) * a[2] * a[3] + 8 * sqrt(6) * a[1] * a[4] - 24 * sqrt(3) * a[3] * a[4] + 4 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5],
-24 * a[2]^2 - 24 * sqrt(2) * a[1] * a[3] - 12 * a[3]^2 - 24 * a[4]^2 + 12 * sqrt(6) * a[1] * a[5] - 12 * sqrt(3) * a[3] * a[5] - 15 * a[5]^2,
-24 * sqrt(2) * a[2] * a[3] - 8 * sqrt(6) * a[1] * a[4] + 8 * sqrt(6) * a[2] * a[5] - 36 * a[4] * a[5],
-12 * a[3]^2 - 8 * sqrt(6) * a[2] * a[4] + 4 * a[4]^2 - 4 * sqrt(6) * a[1] * a[5] + 8 * sqrt(3) * a[3] * a[5] - 17 * a[5]^2,
-8 * sqrt(3) * a[3] * a[4] - 4 * sqrt(6) * a[2] * a[5] + 12 * a[4] * a[5],
-4 * a[4]^2 - 4 * sqrt(3) * a[3] * a[5] + 5 * a[5]^2,
-4 * a[4] * a[5],
-a[5]^2]
# P is evaluated at the standardised point (x - μ) / σ; Normal(μ, σ) is rebuilt on each call.
func(x) = σ * Polynomial(poly_coeffs).((x .- μ) / σ) .* pdf.(Normal(μ, σ), x) ./ 24.0 .+ cdf.(Normal(μ, σ), x)
return func
end
"""
    hermite5_pdf(a, μ, σ) -> func

Return a closure that evaluates, broadcasting over `x`,

    P((x - μ) / σ)^2 * pdf(Normal(μ, σ), x)

where `P` is the degree-4 polynomial whose coefficients (constant term first, per
Polynomials.jl) are fixed linear combinations of `a[1:5]`. The square keeps the result
non-negative.
"""
function hermite5_pdf(a::AbstractArray{TV, 1}, μ::TV, σ::TV) where TV
    c0 = a[1] - a[3] / sqrt(2) + sqrt(3/8) * a[5]
    c1 = a[2] - sqrt(3/2) * a[4]
    c2 = a[3] / sqrt(2) - sqrt(3/2) * a[5]
    c3 = a[4] / sqrt(6)
    c4 = a[5] / sqrt(24)
    coeffs = TV[c0, c1, c2, c3, c4]
    return function (x)
        standardized = (x .- μ) / σ
        poly_vals = Polynomial(coeffs).(standardized)
        return poly_vals .^ 2 .* pdf.(Normal(μ, σ), x)
    end
end
"""
    bernstein3(a; upper = 1.0, lower = 0.0) -> func

Return a scalar closure evaluating the cubic Bernstein polynomial with control values
`a[1:4]`, rescaled from `[0, 1]` to `[lower, upper]`:

    x ↦ B(t),  t = (x - lower) / (upper - lower),
    B(t) = a[1](1-t)^3 + 3a[2] t (1-t)^2 + 3a[3] t^2 (1-t) + a[4] t^3
"""
function bernstein3(a::AbstractArray{TV, 1}; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    width = upper - lower
    return function (x)
        t = (x - lower) / width
        return a[1] * (1 - t)^3 + 3 * a[2] * t * (1 - t)^2 + 3 * a[3] * t^2 * (1 - t) + a[4] * t^3
    end
end
# `solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T` is defined once, further down
# (after the "# From PolynomialRoots.jl" comment). A second, identical method definition
# here would just be overwritten by that one: Julia warns about the overwritten method,
# and duplicate definitions break incremental compilation when this code lives in a package.
# `bernstein3_inv(p, val; upper, lower)` is defined once, further down (right after the
# PolynomialRoots.jl cubic solver it relies on). A second, identical method definition here
# would just be overwritten by that one and break incremental compilation in a package.
# From PolynomialRoots.jl
"""
    solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T

Cubic equation solver for a complex polynomial of degree 3, using Lagrange's method
(http://en.wikipedia.org/wiki/Cubic_function). `poly` holds the coefficients constant
term first: `poly[1] + poly[2] z + poly[3] z^2 + poly[4] z^3`.

Return the real part of the first of the three roots that is numerically real
(`|imag| < 1e-8`) and lies strictly inside `(0, 1)`. If no root qualifies, return `NaN`.
All constants are expressed in `float(T)`, and the `NaN` has that type too, so the return
type depends only on `T`. This keeps the function type-stable for `Float32` or dual-number
inputs.
"""
function solve_cubic_eq(poly::AbstractVector{Complex{T}}) where T
    F = float(T)
    half = F(1) / 2
    third = F(1) / 3
    a1 = 1 / poly[4]
    E1 = -poly[3] * a1
    E2 = poly[2] * a1
    E3 = -poly[1] * a1
    s0 = E1
    E12 = E1 * E1
    A = 2 * E1 * E12 - 9 * E1 * E2 + 27 * E3   # = s1^3 + s2^3
    B = E12 - 3 * E2                           # = s1 s2
    # quadratic equation: z^2 - Az + B^3 = 0 where roots are equal to s1^3 and s2^3
    Δ = sqrt(A * A - 4 * B * B * B)
    # scalar product to decide the sign yielding bigger magnitude
    if real(conj(A) * Δ) >= 0
        s1 = exp(log(half * (A + Δ)) * third)
    else
        s1 = exp(log(half * (A - Δ)) * third)
    end
    s2 = iszero(s1) ? s1 : B / s1
    zeta1 = complex(-half, sqrt(F(3)) * half)   # primitive cube root of unity
    zeta2 = conj(zeta1)
    roots = (third * (s0 + s1 + s2),
             third * (s0 + s1 * zeta2 + s2 * zeta1),
             third * (s0 + s1 * zeta1 + s2 * zeta2))
    for root in roots
        if abs(imag(root)) < 1e-8 && 0 < real(root) < 1
            return real(root)
        end
    end
    return F(NaN)
end
"""
    bernstein3_inv(p, val; upper = 1.0, lower = 0.0)

Invert the cubic Bernstein map `bernstein3(p; upper, lower)`: return `x` in
`[lower, upper]` at which the map equals `val`.

With `t = (x - lower) / (upper - lower)`, the equation is rewritten as
`a t^3 + b t^2 + c t + d = 0`.
- If `d >= 0` (the map starts at or above `val`), return `lower`.
- If `a + b + c + d <= 0` (the map ends at or below `val`), return `upper`.
- Otherwise, find the root in `(0, 1)` with `solve_cubic_eq`.

When the cubic coefficient `a` is exactly zero, the root is computed directly from the
quadratic/linear equation instead. This happens, for example, with equally spaced control
values, for which the map is linear; passing a zero leading coefficient to the cubic
solver would divide by zero and yield `NaN`.
"""
function bernstein3_inv(p::AbstractArray{TV, 1}, val::Float64; upper::Float64 = 1.0, lower::Float64 = 0.0) where TV
    # Get this in the form at^3 + bt^2 + ct + d = 0
    a = p[4] + 3 * p[2] - 3 * p[3] - p[1]
    b = 3 * p[1] - 6 * p[2] + 3 * p[3]
    c = 3 * p[2] - 3 * p[1]
    d = p[1] - val
    if d >= 0.0 # (poly(0.0))
        return lower
    elseif a + b + c + d <= 0.0 # (poly(1.0))
        return upper
    end
    t = iszero(a) ? _bernstein3_root_low_degree(b, c, d) : solve_cubic_eq(Complex.([d, c, b, a]))
    return t * (upper - lower) + lower
end

# Root in [0, 1] of b*t^2 + c*t + d, used when the cubic coefficient vanishes. The caller
# guarantees d < 0 < b + c + d, so the polynomial changes sign on (0, 1).
function _bernstein3_root_low_degree(b, c, d)
    if iszero(b)
        return -d / c   # linear case; c != 0 because c + d > 0 > d
    end
    s = sqrt(max(c^2 - 4 * b * d, zero(c^2)))
    r1 = (-c + s) / (2 * b)
    return 0 <= r1 <= 1 ? r1 : (-c - s) / (2 * b)
end
# Here, we take in a set of choice sets, each of which can be repeated N times.
"""
    logit_probabilities!(inside_probabilities, outside_probabilities, utilities, N, σinv)

Fill multinomial-logit choice probabilities in which the outside option has utility 0 and
choice set `k` appears `N[k]` times.

For every column `(uJ_index, bias_index)` of the 3-d array `utilities`, with
`D = 1 + Σ_k N[k] * exp(σinv * u_k)`:
- choice set `j` gets probability `exp(σinv * u_j) / D`;
- the outside option gets `1 / D`.

Results are written into `inside_probabilities` (same shape as `utilities`) and
`outside_probabilities` (indexed `[uJ_index, bias_index]`). Returns `nothing`.

Before exponentiating, all exponents are shifted by `max(0, maximum(σinv * u))`. This
leaves the probabilities unchanged mathematically, but large utilities no longer overflow
to `Inf` (an unshifted `Inf / Inf` would give `NaN`).
"""
function logit_probabilities!(inside_probabilities::Array{TV, 3}, outside_probabilities::Matrix{TV}, utilities::Array{TV, 3}, N::Vector{Int64}, σinv::TV) where TV
    for bias_index in axes(utilities, 3), uJ_index in axes(utilities, 2)
        scaled = @views utilities[:, uJ_index, bias_index] .* σinv
        # The outside option contributes exponent 0, hence `init = 0` (also covers empty columns).
        shift = maximum(scaled; init = zero(TV))
        exp_scaled = exp.(scaled .- shift)
        outside_weight = exp(-shift)
        inv_denom = 1 / (outside_weight + dot(exp_scaled, N))
        inside_probabilities[:, uJ_index, bias_index] .= exp_scaled .* inv_denom
        outside_probabilities[uJ_index, bias_index] = outside_weight * inv_denom
    end
    return nothing
end
"""
    partial_interpolate(x, y, idx, val)

Linearly interpolate the tabulated function `y`, sampled on the grid `x`, at `val`, given
`idx = searchsortedlast(x, val)` (so `x[idx] <= val < x[idx+1]`).

Outside the grid the function value is held constant:
- `idx <= 0` (`val` below `x[1]`) returns `y[1]`;
- `idx >= length(x)` (`val` at or beyond `x[end]`) returns `y[end]`.

Previously the grid coordinates `x[1]` / `x[end]` were returned here instead of function
values. The result is converted to `TV`, the element type of `y`.
"""
function partial_interpolate(x::AbstractArray{Float64, 1}, y::AbstractArray{TV, 1}, idx::Int64, val::Float64) where TV
    if idx <= 0
        return convert(TV, y[1])
    elseif idx >= length(x)
        return convert(TV, y[end])
    end
    return convert(TV, (y[idx+1] - y[idx]) * (val - x[idx]) / (x[idx+1] - x[idx]) + y[idx])
end
"""
    get_spec() -> SpecFormat

Return the hard-coded model specification:
- special banks 1, 14 and 16, and home bank 12;
- at most 3 banks per choice set;
- Hermite order 5 and Bernstein order 4;
- the covariate names used by each model component.
"""
function get_spec()::SpecFormat
    covariates(xs::Symbol...) = Union{Symbol, String}[xs...]
    return SpecFormat(
        [1, 14, 16],   # special_banks
        12,            # home_bank
        3,             # max_banks_in_choice_set
        5,             # hermite_order
        4,             # bernstein_order
        covariates(:ones, :x_educ_high, :x_income),                      # initial_search_cost_X
        covariates(:ones, :x_educ_high, :x_income, :s_n_branches),       # per_bank_search_cost_X
        covariates(:ones, :x_educ_high, :x_income),                      # ρ_X
        covariates(:ones, :x_educ_high, :x_income),                      # γ_X
        covariates(:ones),                                               # γ_std_X
        covariates(:ones, :x_educ_high, :x_income),                      # κ_X
        covariates(:ones, :x_educ_high, :x_income, :b_term, :b_amount),  # initial_mean_shifters
        covariates(:ones, :b_term, :b_amount),                           # initial_std_shifters
        covariates(:ones, :x_educ_high, :x_income, :b_term, :b_amount),  # final_mean_shifters
        covariates(:ones, :b_term, :b_amount),                           # final_std_shifters
    )
end
"""
    get_indices(grid, vals) -> Vector{Int64}

For every value `v` in `vals`, return `searchsortedlast(grid, v)`. This is the index of
the last grid point `<= v`, or 0 when `v < grid[1]`. `grid` must be sorted in ascending
order.
"""
function get_indices(grid::AbstractArray{Float64, 1}, vals::AbstractArray{Float64, 1})::Vector{Int64}
    positions = Vector{Int64}(undef, length(vals))
    for (slot, v) in zip(eachindex(positions), vals)
        positions[slot] = searchsortedlast(grid, v)
    end
    return positions
end
"""
    get_param_indices(spec::SpecFormat) -> (ParamIndices, SubParamIndices)

Lay out the parameter vectors implied by `spec`.

The first element describes the full parameter vector. It holds consecutive index blocks
for:
- per-bank expansion coefficients (`spec.hermite_order` each; one block per special bank,
  plus one more);
- the jump-bid (`hermite_order`) and implied-jump-bid (`bernstein_order`) coefficients;
- the coefficients on each covariate list;
- the bank mean/std shifters;
- the two logit-scale scalars.

It also records, for every covariate list except `per_bank_search_cost_X`, the column
positions of its covariates inside the sorted union of those lists.

The second element describes the per-observation ("sub") parameter vector: eleven
scalars, then the per-bank search costs, the per-bank coefficient blocks and the jump-bid
blocks.

Covariate names may be Symbols or Strings and are matched by name, so `:x` and `"x"`
denote the same covariate. Normalising them to Symbols also avoids `sort!` failing on a
mix of Symbols and Strings, which have no common ordering.

Throws `ArgumentError` if the initial and final mean (or std) shifter lists differ in
length.
"""
function get_param_indices(spec::SpecFormat)::Tuple{ParamIndices, SubParamIndices}
    length(spec.initial_mean_shifters) == length(spec.final_mean_shifters) ||
        throw(ArgumentError("initial_mean_shifters and final_mean_shifters must have the same length"))
    length(spec.initial_std_shifters) == length(spec.final_std_shifters) ||
        throw(ArgumentError("initial_std_shifters and final_std_shifters must have the same length"))
    n_bank_groups = length(spec.special_banks) + 1

    # ---- Full parameter vector: consecutive index blocks, in this order ----
    param_val::Int64 = 0
    bank_params = Vector{Vector{Int64}}()
    for _ in 1:n_bank_groups
        this_bank_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
        push!(bank_params, this_bank_params)
    end
    jump_bid_params::Vector{Int64}, param_val = param_update(spec.hermite_order, param_val)
    implied_jump_bid_params::Vector{Int64}, param_val = param_update(spec.bernstein_order, param_val)
    initial_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.initial_search_cost_X), param_val)
    per_bank_search_cost_X::Vector{Int64}, param_val = param_update(length(spec.per_bank_search_cost_X), param_val)
    ρ_X::Vector{Int64}, param_val = param_update(length(spec.ρ_X), param_val)
    γ_X::Vector{Int64}, param_val = param_update(length(spec.γ_X), param_val)
    γ_std_X::Vector{Int64}, param_val = param_update(length(spec.γ_std_X), param_val)
    κ_X::Vector{Int64}, param_val = param_update(length(spec.κ_X), param_val)
    bank_mean_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_mean_shifters), param_val)
    bank_std_shifters::Vector{Int64}, param_val = param_update(length(spec.initial_std_shifters), param_val)
    σinv_choice::Int64, param_val = param_update(1, param_val; return_as_vector = false)
    σinv_start_search::Int64, param_val = param_update(1, param_val; return_as_vector = false)

    # ---- Column positions in the individual-level covariate matrix ----
    # Per-bank search-cost covariates live at the individual-bank level, so they are excluded.
    indexed_lists = (spec.initial_search_cost_X, spec.ρ_X, spec.γ_X, spec.γ_std_X, spec.κ_X,
                     spec.initial_mean_shifters, spec.initial_std_shifters,
                     spec.final_mean_shifters, spec.final_std_shifters)
    all_X = sort!(unique!(Symbol[Symbol(name) for list in indexed_lists for name in list]))
    positions(list) = Int64[findfirst(==(Symbol(name)), all_X) for name in list]

    param_indices = ParamIndices(bank_params, jump_bid_params, implied_jump_bid_params,
        initial_search_cost_X, per_bank_search_cost_X, ρ_X, γ_X, γ_std_X, κ_X,
        bank_mean_shifters, bank_std_shifters,
        positions(spec.initial_search_cost_X), positions(spec.ρ_X), positions(spec.γ_X),
        positions(spec.γ_std_X), positions(spec.κ_X),
        positions(spec.initial_mean_shifters), positions(spec.initial_std_shifters),
        positions(spec.final_mean_shifters), positions(spec.final_std_shifters),
        σinv_choice, σinv_start_search, param_val)

    # ---- Per-observation sub-parameter vector ----
    sub_param_val::Int64 = 0
    sub_κ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_γ_std::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_search_cost::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_ρ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_initial_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_μ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_final_bank_σ::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_choice::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_σinv_start_search::Int64, sub_param_val = param_update(1, sub_param_val; return_as_vector = false)
    sub_per_bank_search_cost::Vector{Int64}, sub_param_val = param_update(n_bank_groups, sub_param_val)
    sub_bank_params = Vector{Vector{Int64}}()
    for _ in 1:n_bank_groups
        this_sub_bank_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
        push!(sub_bank_params, this_sub_bank_params)
    end
    sub_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.hermite_order, sub_param_val)
    sub_implied_jump_bid_params::Vector{Int64}, sub_param_val = param_update(spec.bernstein_order, sub_param_val)
    sub_param_indices = SubParamIndices(sub_κ, sub_γ, sub_γ_std, sub_initial_search_cost, sub_ρ,
        sub_initial_bank_μ, sub_initial_bank_σ, sub_final_bank_μ, sub_final_bank_σ,
        sub_σinv_choice, sub_σinv_start_search, sub_per_bank_search_cost, sub_bank_params,
        sub_jump_bid_params, sub_implied_jump_bid_params, sub_param_val)
    return (param_indices, sub_param_indices)
end
"""
    param_update(N, param_val; return_as_vector = true) -> (indices, new_param_val)

Reserve the next `N` parameter positions after `param_val`.

- If `return_as_vector` is true or `N > 1`, return the positions
  `param_val+1 : param_val+N` as a `Vector{Int64}` together with `param_val + N`.
- Otherwise (`N <= 1` and `return_as_vector == false`), return the single position
  `param_val + 1` as an `Int64`, and advance the counter by exactly one.

Note that the return type depends on the arguments' values, not only their types.
"""
function param_update(N::Int64, param_val::Int64; return_as_vector::Bool = true)
    as_scalar = !return_as_vector && N <= 1
    if as_scalar
        next_slot = param_val + 1
        return next_slot, next_slot
    end
    return collect(param_val .+ (1:N)), param_val + N
end
"""
    get_ll_single(params, param_indices, uJ_grid, base_nodes, weights, u_current_idx,
                  remaining_term, amount, chosen_bank, current_monthly_payment,
                  final_monthly_payment, search_type, choice_sets, choice_sets_N,
                  chosen_choice_set) -> log-likelihood

Return the log-likelihood contribution of one observation for the sub-parameter vector
`params`, laid out according to `param_indices`.

Outline:
1. Read the scalar frictions from `params`.
2. Build the bias nodes `bias_grid = γ + sqrt(2) * γ_std * base_nodes`.
3. Build closures for the bank-offer CDFs/PDFs (`hermite5_*`) and the home-bank map
   `implied_jump_bid` (`bernstein3`).
4. For every bias node and choice set, compute the utility of searching that choice set
   on `uJ_grid`.
5. Turn those utilities into logit probabilities (`logit_probabilities!`) and an
   expected benefit of search per `uJ` and bias node.
6. Combine the benefit with the initial jump-bid density/CDF at the current utility
   `-current_monthly_payment` to get the value of starting to search.
7. Map that value to a search probability and return
   `log(1 - ρ + ρ * Σ_k weights[k] * Pr(no search | bias node k))`.

`uJ_grid[u_current_idx] <= -current_monthly_payment < uJ_grid[u_current_idx + 1]` is
assumed (the index is passed to `partial_interpolate`). `search_type` and
`chosen_choice_set` are not used here. The `final_*` closures, `implied_jump_bid_inv`,
`refi_at_home_bank` and `u_final` are built but not used below.

All intermediate containers holding parameter-dependent values have element type `TV`,
so non-`Float64` number types are not silently forced through `Float64` storage.

NOTE(review): `partial_cumtrapz` is called below but is not defined in this file; it must
be provided elsewhere.
NOTE(review): the vectors of closures (`initial_bank_cdf`, `final_bank_*`) capture the
active `params` together with constant index vectors. This is a plausible source of
Enzyme's "Mismatched activity" error (see Enzyme's FAQ on activity of temporary storage).
TODO confirm.
"""
function get_ll_single(params::AbstractArray{TV, 1}, param_indices::SubParamIndices,
                       uJ_grid::Vector{Float64}, base_nodes::SVector{9, Float64},
                       weights::SVector{9, Float64}, u_current_idx::Int64,
                       remaining_term::Float64, amount::Float64, chosen_bank::Int64,
                       current_monthly_payment::Float64, final_monthly_payment::Float64,
                       search_type::Int64, choice_sets::Vector{Vector{Int64}},
                       choice_sets_N::Vector{Int64}, chosen_choice_set::Int64) where TV
    # Initialize the values of the frictions -- this applies for everyone
    κ::TV = params[param_indices.κ]
    γ::TV = params[param_indices.γ]
    γ_std::TV = params[param_indices.γ_std]
    initial_search_cost::TV = params[param_indices.initial_search_cost]
    ρ::TV = params[param_indices.ρ]
    initial_bank_μ::TV = params[param_indices.initial_bank_μ]
    initial_bank_σ::TV = params[param_indices.initial_bank_σ]
    final_bank_μ::TV = params[param_indices.final_bank_μ]
    final_bank_σ::TV = params[param_indices.final_bank_σ]
    σinv_choice::TV = params[param_indices.σinv_choice]
    σinv_start_search::TV = params[param_indices.σinv_start_search]

    # Note that the cdf of zero-profit rates is written in terms of monthly payments if A = 1, scaled by number of
    # payments. So, a value of m corresponds to a monthly payment of m * A / T.
    #
    # We know that monthly utility = (delta - m * A / T) - kappa / NPVrate(T),
    # So, Pr(utility <= u) = Pr((delta - m * A / T) - kappa/NPV <= u) = Pr(scaled monthly payment >= (T/A) * (delta - (u + kappa/NPV)))
    # The bias is in terms of monthly payments. So, we just add the bias.
    β_customer_m::Float64 = 0.95^(1/12)   # monthly discount factor from an annual 0.95
    per_bank_search_costs::Vector{TV} = params[param_indices.per_bank_search_cost]
    npv_scale::Float64 = (1.0 - β_customer_m^remaining_term) / (1 - β_customer_m)
    overall_scale::Float64 = remaining_term / amount
    bias_grid = base_nodes * γ_std * sqrt(2) .+ γ # in units of dollars
    scaled_κ::TV = κ / npv_scale # in units of dollars per month

    # Generate the cdfs and pdfs for the banks: cdfs at the time of search, and both later if needed
    # Maps utilities ($) to utilities ($) <--- could rethink whether this is the best map
    implied_jump_bid::Function = bernstein3(params[param_indices.implied_jump_bid_params]; upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)
    implied_jump_bid_inv(val::Float64) = bernstein3_inv(params[param_indices.implied_jump_bid_params], val; upper = maximum(uJ_grid) + 2.0, lower = minimum(uJ_grid) - 2.0)
    initial_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], initial_bank_μ, initial_bank_σ)(-overall_scale .* (x .+ scaled_κ)) for a in param_indices.bank_params]
    initial_jump_bid_cdf::Function = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_jump_bid_pdf::Function = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], initial_bank_μ, initial_bank_σ)(-overall_scale .* x)
    initial_home_bank_cdf(x) = initial_jump_bid_cdf(implied_jump_bid.(x))
    final_bank_cdf = [x -> 1.0 .- hermite5_cdf(params[a], final_bank_μ, final_bank_σ)(@. -overall_scale * (x + scaled_κ)) for a in param_indices.bank_params]
    final_bank_pdf = [x -> overall_scale .* hermite5_pdf(params[a], final_bank_μ, final_bank_σ)(@. -overall_scale * (x + scaled_κ)) for a in param_indices.bank_params]
    final_jump_bid_cdf::Function = x -> 1.0 .- hermite5_cdf(params[param_indices.jump_bid_params], final_bank_μ, final_bank_σ)(@. -overall_scale * x)
    final_jump_bid_pdf::Function = x -> overall_scale .* hermite5_pdf(params[param_indices.jump_bid_params], final_bank_μ, final_bank_σ)(@. -overall_scale * x)
    final_home_bank_cdf(x) = final_jump_bid_cdf(implied_jump_bid.(x))
    refi_at_home_bank::Bool = chosen_bank == 0
    u_current::Float64 = -current_monthly_payment
    u_final::TV = -final_monthly_payment - scaled_κ * refi_at_home_bank

    # Work buffers; every parameter-dependent container uses element type TV.
    choice_set_utilities::Array{TV, 3} = zeros(TV, length(choice_sets), length(uJ_grid), length(bias_grid))
    base_0::Vector{TV} = zeros(TV, length(uJ_grid))
    base_1::Vector{TV} = zeros(TV, length(uJ_grid))
    to_add::Vector{TV} = zeros(TV, length(uJ_grid))
    initial_bank_cdf_vals::Matrix{TV} = zeros(TV, length(initial_bank_cdf), length(uJ_grid))
    one_minus_initial_bank_cdf_vals::Matrix{TV} = zeros(TV, length(initial_bank_cdf), length(uJ_grid))
    @inbounds for (bias_index, bias) in enumerate(bias_grid)
        biased_uJ = uJ_grid .+ bias
        for b = 1:length(initial_bank_cdf)
            initial_bank_cdf_vals[b, :] .= initial_bank_cdf[b](biased_uJ)
            one_minus_initial_bank_cdf_vals[b, :] .= @views 1.0 .- initial_bank_cdf_vals[b, :]
        end
        initial_home_bank_cdf_vals::Vector{TV} = initial_home_bank_cdf(biased_uJ)
        # Grid points where the home-bank CDF is still (numerically) below its top value;
        # elsewhere the home-bank correction below would divide by ~0.
        good_indices::BitVector = (initial_home_bank_cdf_vals[end] .- initial_home_bank_cdf_vals) .>= 1e-7
        @inbounds for (choice_set_index, choice_set) in enumerate(choice_sets)
            # base_0: product over the choice set of the initial bank CDF values.
            base_0 .= @views initial_bank_cdf_vals[choice_set[1], :]
            for (b_idx, b) in enumerate(choice_set)
                b_idx > 1 || continue
                base_0 .*= @views initial_bank_cdf_vals[b, :]
            end
            cumtrapz_0 = cumtrapz_upper(uJ_grid, base_0)
            # base_1: Σ_b (1 - F_b) * Π_{b' != b} F_{b'} over the choice set.
            base_1 .= 0.0
            for (b_index, b) in enumerate(choice_set)
                to_add .= @views one_minus_initial_bank_cdf_vals[b, :]
                for (bprime_index, bprime) in enumerate(choice_set)
                    if bprime_index != b_index
                        to_add .*= @views initial_bank_cdf_vals[bprime, :]
                    end
                end
                base_1 .+= to_add
            end
            cumtrapz_1 = cumtrapz_upper(uJ_grid, base_1)
            base_1 .*= initial_home_bank_cdf_vals
            cumtrapz_1_with_home = cumtrapz_upper(uJ_grid, base_1)
            # Reuse base_0's buffer: it is no longer needed once cumtrapz_0 is computed.
            next_part = base_0
            for i in eachindex(next_part)
                common = uJ_grid[end] - uJ_grid[i] - cumtrapz_0[i]
                # Explicit if/else: the former `c ? next_part[i] = A : B` parsed as
                # `c ? (next_part[i] = A) : B`, so the fallback B was never stored.
                if good_indices[i]
                    next_part[i] = common - (cumtrapz_1_with_home[i] - initial_home_bank_cdf_vals[i] * cumtrapz_1[i]) / (initial_home_bank_cdf_vals[end] - initial_home_bank_cdf_vals[i])
                else
                    next_part[i] = common - cumtrapz_1[i]
                end
            end
            this_search_cost::TV = sum([per_bank_search_costs[b] for b in choice_set])
            choice_set_utilities[choice_set_index, :, bias_index] .= @. next_part - this_search_cost
        end
    end
    choice_set_probabilities::Array{TV, 3} = zeros(TV, size(choice_set_utilities))
    outside_option_probabilities::Array{TV, 2} = zeros(TV, length(uJ_grid), length(bias_grid))
    logit_probabilities!(choice_set_probabilities, outside_option_probabilities,
        choice_set_utilities, choice_sets_N, σinv_choice)
    # Expected benefit of search per (uJ, bias node): Σ_k N[k] * Pr_k * utility_k.
    benefit_of_search_uJ::Array{TV, 2} = dropdims(sum((choice_sets_N .* choice_set_probabilities) .* choice_set_utilities, dims = 1), dims = 1)
    S_base::Vector{TV} = zeros(TV, length(bias_grid))
    for (bias_index, bias) in enumerate(bias_grid)
        initial_jump_bid_pdf_vals = initial_jump_bid_pdf(uJ_grid .+ bias)
        S_base[bias_index] = @views partial_cumtrapz(uJ_grid, benefit_of_search_uJ[1:(u_current_idx+1), bias_index] .* initial_jump_bid_pdf_vals[1:(u_current_idx+1)], u_current_idx, u_current) + initial_jump_bid_cdf(u_current + bias) *
            partial_interpolate(uJ_grid, benefit_of_search_uJ[:, bias_index], u_current_idx, u_current)
    end
    # Probability of starting to search, one entry per bias node. The logistic is written as
    # 1 / (1 + exp(-s)) so a large positive index s gives 1 rather than Inf / Inf = NaN.
    pr_search::Vector{TV} = @. 1.0 / (1.0 + exp(-σinv_start_search * (S_base - initial_search_cost)))
    pr_no_search = 1 .- pr_search
    likelihood = 1 - ρ + ρ * dot(pr_no_search, weights)
    return log(likelihood)
end
"""
    call_autodiff_example()

Reverse-mode Enzyme example: differentiate `get_ll_single` with respect to `params` for
one hard-coded observation. The gradient is accumulated into the shadow vector `dx`.
Returns whatever `autodiff` returns.

Throws `DimensionMismatch` if the hard-coded `params` does not match the sub-parameter
layout produced by `get_param_indices`.
"""
function call_autodiff_example()
    spec::SpecFormat = get_spec()
    sub_param_indices::SubParamIndices = get_param_indices(spec)[2]
    # Utility grid: 0, -0.05, ..., -2.0 plus coarser points down to -6.0, sorted ascending.
    uJ_grid::Vector{Float64} = union(collect(0:0.05:2.0), [2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0]) # this will be a constant
    uJ_grid = sort!(-uJ_grid)
    # 9-point Gauss–Hermite rule; dividing by sqrt(π) makes the weights sum to one.
    base_nodes, base_weights = SVector{9}.(FastGaussQuadrature.gausshermite(9))
    weights = SVector{9}(base_weights / sqrt(π))
    params = [0.0, 0.0, 0.25, 0.0, 0.1, 1.2, 0.25, 1.2, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0,
        1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0]
    length(params) == sub_param_indices.n_params ||
        throw(DimensionMismatch("params has length $(length(params)), expected $(sub_param_indices.n_params)"))
    remaining_term = 67.0
    amount = 13.32
    chosen_bank = -1
    current_monthly_payment = 0.24
    final_monthly_payment = NaN
    search_type = 1
    # get_ll_single interpolates at u_current = -current_monthly_payment using this index,
    # so derive it from the grid instead of hard-coding it (it is 44 for this grid).
    u_current_idx = searchsortedlast(uJ_grid, -current_monthly_payment)
    choice_sets = [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3], [1, 3, 4],
        [1, 4], [1, 4, 4], [2], [2, 3], [2, 3, 4], [2, 4], [2, 4, 4],
        [3], [3, 4], [3, 4, 4], [4], [4, 4], [4, 4, 4]]
    choice_sets_N = [1, 1, 1, 15, 1, 15, 15, 105, 1, 1, 15, 15, 105, 1, 15, 105, 15, 105, 455]
    chosen_choice_set = 0
    dx = zero(params)   # shadow of params: receives the gradient
    single_ll(x) = get_ll_single(x, sub_param_indices, uJ_grid, base_nodes, weights, u_current_idx,
        remaining_term, amount, chosen_bank, current_monthly_payment, final_monthly_payment, search_type, choice_sets, choice_sets_N, chosen_choice_set)
    return autodiff(Reverse, single_ll, Active, Duplicated(params, dx))
end
call_autodiff_example()

When I run this, I get an error with the following stack trace:
ERROR: LoadError: Enzyme execution failed.
Mismatched activity for: store {} addrspace(10)* %value_phi5253395, {} addrspace(10)** %.fca.1.gep, align 8, !dbg !1671, !noalias !216 const val: %value_phi5253395 = phi {} addrspace(10)* [ %arrayref514, %L1631.lr.ph ], [ %arrayref1084, %L3207 ]
Type tree: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Integer, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}
llvalue= %arrayref514 = load {} addrspace(10)*, {} addrspace(10)* addrspace(13)* %arrayptr5112966, align 8, !dbg !1096, !tbaa !1097, !alias.scope !197, !noalias !198
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now
Stacktrace:
[1] get_ll_single
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:527
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:1612
[2] get_ll_single
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:527
[3] *
@ ./float.jl:411 [inlined]
[4] trapz
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:87
[5] macro expansion
@ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6587 [inlined]
[6] enzyme_call
@ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6188 [inlined]
[7] CombinedAdjointThunk
@ ~/.julia/packages/Enzyme/aioBJ/src/compiler.jl:6065 [inlined]
[8] autodiff
@ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:309 [inlined]
[9] autodiff
@ ~/.julia/packages/Enzyme/aioBJ/src/Enzyme.jl:321 [inlined]
[10] call_autofiff_example()
@ Main ~/Documents/refinancing/multistep/example_Enzyme.jl:592
[11] top-level scope
@ ~/Documents/refinancing/multistep/example_Enzyme.jl:596
[12] include(fname::String)
@ Base.MainInclude ./client.jl:489
[13] top-level scope
@ REPL[2]:1
in expression starting at /Users/miguelborrero/Documents/refinancing/multistep/example_Enzyme.jl:596
Any advice would be greatly appreciated! Thanks a lot in advance!