Type unstable gradients (@code_warntype) #1476

Closed
leonardogalliano opened this issue Dec 8, 2023 · 1 comment

@leonardogalliano

Hi,
I've noticed some type instability when trying to use this (very nice) package. Can you please help me understand why?

Here’s a simple example: I define a model called GaussianModel with a single parameter, σ. Calling the model returns the value of that parameter. I'd then like to differentiate the function logprob, which takes the model as an argument, using implicit gradients.

using LinearAlgebra, Random, Zygote

abstract type Model end

mutable struct GaussianModel{T<:AbstractFloat} <: Model
    σ::T
end

(model::GaussianModel)() = model.σ

function logprob(x::T, y::T, model::GaussianModel{T}) where {T<:AbstractFloat}
    σ = model()
    return -(y - x)^2 / (2σ^2) - log(2 * π * σ^2) / 2
end
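As a reference point for what "the right result" should be, the derivative of this log-density with respect to σ can be written by hand as (y - x)^2 / σ^3 - 1 / σ. The sketch below compares that expression with a central finite difference in plain Julia, without Zygote. The helper names `logprob_sigma` and `dlogprob_dsigma` are hypothetical and only used for illustration.

```julia
# Stand-alone copy of the log-density, taking σ directly (hypothetical helper).
logprob_sigma(x, y, σ) = -(y - x)^2 / (2σ^2) - log(2π * σ^2) / 2

# Hand-derived derivative with respect to σ: (y - x)^2 / σ^3 - 1 / σ
dlogprob_dsigma(x, y, σ) = (y - x)^2 / σ^3 - 1 / σ

x, y, σ = 0.2, 0.1, 0.5

# Central finite difference as an independent numerical check.
h = 1e-6
fd = (logprob_sigma(x, y, σ + h) - logprob_sigma(x, y, σ - h)) / (2h)
analytic = dlogprob_dsigma(x, y, σ)

println("finite difference: ", fd)
println("analytic:          ", analytic)   # ≈ -1.92 for these inputs
```

For x = 0.2, y = 0.1, σ = 0.5 the two values should agree closely, which gives a number to compare against the σ entry of whatever the automatic differentiation returns.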

Now, although this

x = 0.2
y = 0.1
mymodel = GaussianModel(0.5)
logprob(x, y, mymodel)
gradient(model -> logprob(x, y, model), mymodel)

gives the right result, this

@code_warntype gradient(model -> logprob(x, y, model), mymodel)

gives the following type instability

MethodInstance for Zygote.gradient(::var"#129#130", ::GaussianModel{Float64})
  from gradient(f, args...) in Zygote at [...]
Arguments
  #self#::Core.Const(Zygote.gradient)
  f::Core.Const(var"#129#130"())
  args::Tuple{GaussianModel{Float64}}
Locals
  @_4::Int64
  grad::Union{Nothing, Tuple}
  back::Zygote.var"#75#76"
  y::Any
Body::Union{Nothing, Tuple{Any}}
1 ─ %1  = Core.tuple(f)::Core.Const((var"#129#130"(),))
│   %2  = Core._apply_iterate(Base.iterate, Zygote.pullback, %1, args)::Tuple{Any, Zygote.var"#75#76"}
│   %3  = Base.indexed_iterate(%2, 1)::Core.PartialStruct(Tuple{Any, Int64}, Any[Any, Core.Const(2)])
│         (y = Core.getfield(%3, 1))
│         (@_4 = Core.getfield(%3, 2))
│   %6  = Base.indexed_iterate(%2, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Zygote.var"#75#76", Int64}, Any[Zygote.var"#75#76", Core.Const(3)])
│         (back = Core.getfield(%6, 1))
│   %8  = Zygote.sensitivity(y)::Any
│         (grad = (back)(%8))
│   %10 = Zygote.isnothing(grad)::Bool
└──       goto #3 if not %10
2 ─       return Zygote.nothing
3 ─ %13 = Zygote.map(Zygote._project, args, grad::Tuple)::Tuple{Any}
└──       return %13
@ToucheSir
Member

This was cross-posted from https://discourse.julialang.org/t/type-unstable-gradients-in-zygote-code-warntype/107279. Closing because that thread has been answered, and to keep the issue tracker clean.
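One general Julia behaviour that may be relevant when reading output like the above: in the reported trace, `f` is shown as `Core.Const(var"#129#130"())`, i.e. the closure has no captured fields, which is consistent with `x` and `y` being read as non-constant globals inside the closure body. Type inference cannot assume a type for such globals, so anything computed from them is inferred as `Any`. The sketch below illustrates only that effect in plain Julia, without Zygote; it is not taken from the linked thread, and the names `x_global`, `y_global`, `logprob_plain`, `f_globals`, and `f_locals` are hypothetical.

```julia
# Non-const globals: their type may change at any time, so inference assumes Any.
x_global = 0.2
y_global = 0.1

# Plain function of three numbers (hypothetical stand-in for logprob).
logprob_plain(x, y, σ) = -(y - x)^2 / (2σ^2) - log(2π * σ^2) / 2

# The closure body reads the non-const globals directly.
f_globals = σ -> logprob_plain(x_global, y_global, σ)

# Same computation, but the values are captured as local variables via `let`,
# so they become concretely typed fields of the closure.
f_locals = let x = x_global, y = y_global
    σ -> logprob_plain(x, y, σ)
end

# Base.return_types reports the inferred return type for a Float64 argument.
inferred_globals = only(Base.return_types(f_globals, Tuple{Float64}))
inferred_locals = only(Base.return_types(f_locals, Tuple{Float64}))

println("closure over globals infers: ", inferred_globals)
println("closure over locals infers:  ", inferred_locals)
```

Under these assumptions, wrapping a call in a function, capturing values with `let`, or declaring the globals `const` are the usual ways to give inference concrete types to work with. Whether that removes every `Any` in the Zygote trace is not something this sketch establishes.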
