You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The following code worked with Enzyme 0.11.20 (Julia 1.10.2) but fails with two different errors on Enzyme 0.12.0 (and on `main`).
Apologies for the not-so-minimal MWE.
Setup
using Pkg

# Work in a throw-away environment so the reproduction leaves the default project alone.
Pkg.activate(; temp = true)
Pkg.add(["Distributions", "ComponentArrays", "ConcreteStructs"])

# Enable exactly one Enzyme source; the commented-out lines are the releases compared in
# this report.
# Pkg.add(name = "Enzyme", version = "0.11.20")
# Pkg.add(name = "Enzyme", version = "0.12.0")
Pkg.add(; name = "Enzyme", rev = "main")

using Distributions, Enzyme, ComponentArrays, ConcreteStructs
"""
    BiasedCoin()

Field-less model tag for a coin whose bias `ρ` is read from the parameter vector.
"""
struct BiasedCoin end

# Default parameters: a fair coin (ρ = 0.5).
parameters(::BiasedCoin) = ComponentArray(ρ = 0.5)

"""
    logp(data, ::BiasedCoin, parameters)

Sum of `logpdf(Bernoulli(parameters.ρ), d)` over every observation `d` in `data`, i.e. the
log-likelihood of independent Bernoulli draws with a fixed success probability.
"""
function logp(data, ::BiasedCoin, parameters)
    p_success = parameters.ρ
    total = 0.0
    for outcome in data
        total += logpdf(Bernoulli(p_success), outcome)
    end
    return total
end

"""
    sigmoid(w)

Logistic function `1 / (1 + exp(-w))`, squashing a real weight into a probability.
"""
sigmoid(w) = 1 / (1 + exp(-w))
# Coin model with a running weight `w`. The weight is kept in a mutable `Ref` so that
# `initialize!` can reset it and `logp` can update it while scanning the data.
@concrete struct HabituatingBiasedCoin
    w
end

# Start with weight 0.0; `initialize!` overwrites it from the parameters before use.
HabituatingBiasedCoin() = HabituatingBiasedCoin(Ref(0.0))
"""
    initialize!(m::HabituatingBiasedCoin, parameters)

Reset the running weight `m.w[]` to the starting value `parameters.w₀` and return that value.
"""
initialize!(m::HabituatingBiasedCoin, parameters) = (m.w[] = parameters.w₀)

# Default parameters: starting weight `w₀` and step size `η` of the weight update, both 0.
parameters(::HabituatingBiasedCoin) = ComponentArray(w₀ = 0.0, η = 0.0)
"""
    logp(data, m::HabituatingBiasedCoin, parameters)

Log-likelihood of `data` for a coin whose bias drifts.

Before each observation `d`, the success probability is `ρ = sigmoid(m.w[])`. After it, the
weight moves by `parameters.η * (d - ρ)`. Mutates `m.w`, which is first reset by
`initialize!`.
"""
function logp(data, m::HabituatingBiasedCoin, parameters)
    initialize!(m, parameters)
    η = parameters.η
    total = 0.0
    for d in data
        ρ = sigmoid(m.w[])
        total += logpdf(Bernoulli(ρ), d)
        m.w[] += η * (d - ρ)
    end
    return total
end

"""
    HessLogP(data, model, params = ComponentArray(parameters(model)))

Callable workspace for the Hessian of `logp(data, model, params)` with respect to `params`.

The Hessian is computed forward-over-reverse: a batched forward-mode pass, with one unit
tangent per parameter, runs through a reverse-mode gradient (`_grad1`/`_grad2`). Call it as
`h(H, x)` to fill the matrix `H`. The fields hold prebuilt Enzyme activity wrappers and
buffers that are reused across calls.
"""
@concrete terse struct HessLogP{T}
    logp          # Const wrapper of the 1-element output buffer written by `logp!`
    dlogp         # Const wrapper of the reverse-mode seed for that buffer
    data          # data with batched forward tangents
    ddata         # reverse-mode shadow of data, with batched forward tangents
    model::T      # Const(model), or BatchDuplicated when the model holds mutable state
    dmodel        # reverse-mode shadow of the model (same wrapper kind as `model`)
    parameters    # BatchDuplicated(params, unit tangents)
    dparameters   # reverse-mode gradient buffer; its tangents become Hessian columns
end

function HessLogP(data, model, params = ComponentArray(parameters(model)))
    dd = _zero(data)
    # Store the data in the same container type as its shadows.
    _data = convert(typeof(dd), data)
    n = length(params)
    # Unit tangent e_i per parameter: the i-th forward lane yields the i-th Hessian column.
    vx = ntuple(i -> begin t = _zero(params); t[i] = 1; t end, n)
    ddx = ntuple(_ -> _zero(params), n)
    # Only models that (transitively) contain mutable memory get shadow copies.
    has_mutable_state = _deep_ismutable(model)
    return HessLogP(Const([0.0]), Const([1.0]),
        BatchDuplicated(_data, ntuple(_ -> _zero(dd), n)),
        BatchDuplicated(_zero(dd), ntuple(_ -> _zero(dd), n)),
        has_mutable_state ? BatchDuplicated(model, ntuple(_ -> _zero(model), n)) : Const(model),
        has_mutable_state ? BatchDuplicated(_zero(model), ntuple(_ -> _zero(model), n)) : Const(model),
        BatchDuplicated(params, vx),
        BatchDuplicated(_zero(params), ddx))
end

# Write `logp(data, model, parameters)` into the 1-element buffer `_logp`. Returning
# `nothing` lets the reverse pass use a `Const` return activity.
function logp!(_logp, data, model, parameters)
    _logp[] = logp(data, model, parameters)
    return nothing
end

"""
    _grad1(logp, dlogp, data, ddata, model, params, dparams)

Reverse-mode gradient of `logp!` for a model without mutable state. The gradient with
respect to `params` is accumulated into `dparams`.
"""
function _grad1(logp, dlogp, data, ddata, model, params, dparams)
    # Annotate every argument explicitly. The model is wrapped in `Const` instead of being
    # passed raw; an unannotated argument is presumably what made Enzyme 0.12 report
    # `_grad1` as "guaranteed to return an error" — confirm against the Enzyme changelog.
    autodiff_deferred(Reverse, logp!, Const,
        DuplicatedNoNeed(logp, dlogp),
        DuplicatedNoNeed(data, ddata),
        Const(model),
        Duplicated(params, dparams))
    return nothing
end

"""
    _grad2(logp, dlogp, data, ddata, model, dmodel, params, dparams)

Reverse-mode gradient of `logp!` for a model with mutable state. The model gets its own
shadow `dmodel`; the gradient with respect to `params` is accumulated into `dparams`.
"""
function _grad2(logp, dlogp, data, ddata, model, dmodel, params, dparams)
    autodiff_deferred(Reverse, logp!, Const,
        DuplicatedNoNeed(logp, dlogp),
        DuplicatedNoNeed(data, ddata),
        DuplicatedNoNeed(model, dmodel),
        Duplicated(params, dparams))
    return nothing
end

# Hessian pass for a Const model: batched forward mode over `_grad1`.
function _hess(h::HessLogP{<:Const})
    return autodiff(Forward, _grad1, h.logp, h.dlogp, h.data, h.ddata,
        h.model, h.parameters, h.dparameters)
end

# Hessian pass for a model with mutable state: batched forward mode over `_grad2`.
function _hess(h::HessLogP)
    return autodiff(Forward, _grad2, h.logp, h.dlogp, h.data, h.ddata,
        h.model, h.dmodel, h.parameters, h.dparameters)
end

# Value-level entry point: decide mutability from the value's type.
_deep_ismutable(x::T) where T = _deep_ismutable(T)
"""
    _deep_ismutable(x::Type)

Return `true` if values of type `x` may contain mutable memory anywhere in their field tree.

- `Union` types are mutable-capable if any alternative is.
- Abstract types have no definite field list (`fieldtypes` throws on them), and some
  concrete subtype could be mutable, so they conservatively report `true`.
"""
function _deep_ismutable(x::Type)
    ismutabletype(x) && return true
    # `Base.uniontypes` is an internal Base helper listing the members of a Union.
    x isa Union && return any(_deep_ismutable, Base.uniontypes(x))
    isabstracttype(x) && return true
    return any(_deep_ismutable, fieldtypes(x))
end

# Elementwise zero of an array, delegating to the element's `_zero` method.
_zero(x::AbstractArray) = _zero.(x)
# Scalars: the additive identity of the same type.
_zero(x::Number) = zero(x)
# References: a fresh RefValue around the zeroed payload (never aliases `x`).
_zero(x::Base.RefValue) = Base.RefValue(_zero(x[]))
# Tuples: zero each element, keeping the tuple's length.
_zero(x::Tuple) = map(_zero, x)
# Named tuples: zero each value and keep the field names `K`.
_zero(x::NamedTuple{K}) where K = NamedTuple{K}(map(_zero, values(x)))
# Fallback for structs: rebuild the same (unparameterised) type from its zeroed fields
# through its positional constructor. NOTE(review): assumes the wrapper type accepts all
# fields positionally — true for the model structs here, confirm for others.
function _zero(x::D) where D
    zeroed_fields = ntuple(i -> _zero(getfield(x, i)), fieldcount(D))
    return D.name.wrapper(zeroed_fields...)
end

"""
    (h::HessLogP)(ddx, x)

Write the Hessian of `logp` at parameters `x` into the matrix `ddx`. Column `i` is the
forward tangent of the gradient along the `i`-th unit direction. Returns `nothing`.
"""
function (h::HessLogP)(ddx, x)
    h.parameters.val .= x   # evaluate at x
    h.dlogp.val .= 1        # reset the reverse-mode seed before every pass
    _hess(h)
    for (col, tangent) in enumerate(h.dparameters.dval)
        ddx[:, col] .= tangent
        tangent .= 0        # clear the tangent buffer for the next call
    end
    return nothing
end
First Error
julia> m =BiasedCoin();
julia> p =parameters(m);
julia> h! =HessLogP(rand(0:1, 10), m);
julia> H =zeros(length(p), length(p));
julia> h!(H, p)
ERROR: Function to differentiate `MethodInstance for _grad1(::Vector{Float64}, ::Vector{Float64}, ::Vector{Int64}, ::Vector{Int64}, ::BiasedCoin, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(ρ = 1,)}}}, ::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(ρ = 1,)}}})` is guaranteed to return an error and doesn't make sense to autodiff. Giving up
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] macro expansion
@ ~/.julia/packages/Enzyme/y2lVn/src/compiler.jl:5718 [inlined]
[3] macro expansion
@ ./none:0 [inlined]
[4] thunk(::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…})
@ Enzyme.Compiler ./none:0
[5] autodiff
@ ~/.julia/packages/Enzyme/y2lVn/src/Enzyme.jl:382 [inlined]
[6] autodiff
@ ~/.julia/packages/Enzyme/y2lVn/src/Enzyme.jl:300 [inlined]
[7] autodiff
@ ~/.julia/packages/Enzyme/y2lVn/src/Enzyme.jl:284 [inlined]
[8] _hess(h::HessLogP{Const{BiasedCoin}})
@ Main ./REPL[22]:2
[9] (::HessLogP{Const{BiasedCoin}})(ddx::Matrix{Float64}, x::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
@ Main ./REPL[32]:4
[10] top-level scope
@ REPL[54]:1
Some type information was truncated. Use `show(err)` to see complete types.
The following code worked with Enzyme 0.11.20 (julia 1.10.2) but fails with two different errors on Enzyme 0.12.0 (and main).
Apologies for the not so minimal MWE.
Setup
First Error
Second Error
The text was updated successfully, but these errors were encountered: