Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup tests for driving with Enzyme #905

Closed
wants to merge 14 commits into from
Closed

Setup tests for driving with Enzyme #905

wants to merge 14 commits into from

Conversation

ChrisRackauckas
Copy link
Member

No description provided.

@ChrisRackauckas
Copy link
Member Author

MWE:

using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme
prob = ODEProblem((u, p, t) -> u .* p, [2.0], (0.0, 1.0), [3.0])

struct senseloss{T}
    sense::T
end
function (f::senseloss)(u0p)
    sum(solve(prob, Tsit5(), u0 = u0p[1:1], p = u0p[2:2], abstol = 1e-12,
        reltol = 1e-12, saveat = 0.1, sensealg = f.sense))
end
u0p = [2.0, 3.0]
dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
Enzyme.gradient(Reverse, senseloss(InterpolatingAdjoint()), u0p)  dup

@ChrisRackauckas
Copy link
Member Author

The error is:

ERROR: Duplicated Returns not yet handled
Stacktrace:
 [1] autodiff(#unused#::ReverseMode{false, FFIABI}, f::Const{senseloss{InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}, #unused#::Type{Duplicated{Any}}, args::Duplicated{Vector{Float64}})
   @ Enzyme C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:205
 [2] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:236 [inlined]
 [3] autodiff
   @ C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:222 [inlined]
 [4] gradient(#unused#::ReverseMode{false, FFIABI}, f::senseloss{InterpolatingAdjoint{0, true, Val{:central}, Nothing}}, x::Vector{Float64})
   @ Enzyme C:\Users\accou\.julia\packages\Enzyme\VS5jo\src\Enzyme.jl:809
 [5] top-level scope
   @ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:94

Not sure why it would throw that since the result is a scalar and I can't get a stacktrace on it.

@codecov
Copy link

codecov bot commented Sep 23, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (bd6e048) 59.16% compared to head (d95e542) 59.59%.

❗ Current head d95e542 differs from pull request most recent head e3f5850. Consider uploading reports for the commit e3f5850 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #905      +/-   ##
==========================================
+ Coverage   59.16%   59.59%   +0.43%     
==========================================
  Files          19       20       +1     
  Lines        4481     4482       +1     
==========================================
+ Hits         2651     2671      +20     
+ Misses       1830     1811      -19     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@wsmoses
Copy link

wsmoses commented Sep 24, 2023

Julia failed to auto deduce type, claiming the return type was any.

What happens if you do Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))

@ChrisRackauckas
Copy link
Member Author

Interesting, so it seems Enzyme has issues with remake. But ignoring that, the issue boils down to:

using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme
struct senseloss{T}
    sense::T
end
function (f::senseloss)(u0p)
	prob = ODEProblem{false}((u, p, t) -> u .* p, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end
u0p = [2.0, 3.0]
du0p = zeros(2)
dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))  dup

@code_warntype senseloss(InterpolatingAdjoint())(u0p)

I can code_warntype and see that it infers now. But it kills my terminal with its segfault so I can't seem to actually be able to read the stacktrace.

@frankschae
Copy link
Member

Here's the output from @code_warntype senseloss(InterpolatingAdjoint())(u0p) from Chris' MWE

MethodInstance for (::senseloss{InterpolatingAdjoint{0, true, Val{:central}, Nothing}})(::Vector{Float64})
  from (f::senseloss)(u0p) @ Main ~/Enzyme905/MWE.jl:13
Arguments
  f::senseloss{InterpolatingAdjoint{0, true, Val{:central}, Nothing}}
  u0p::Vector{Float64}
Locals
  #5::var"#5#6"
  prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}
Body::Float64
1%1  = Core.apply_type(Main.ODEProblem, false)::Core.Const(ODEProblem{false})
│         (#5 = %new(Main.:(var"#5#6")))%3  = #5::Core.Const(var"#5#6"())%4  = (1:1)::Core.Const(1:1)
│   %5  = Base.getindex(u0p, %4)::Vector{Float64}%6  = Core.tuple(0.0, 1.0)::Core.Const((0.0, 1.0))
│   %7  = (2:2)::Core.Const(2:2)
│   %8  = Base.getindex(u0p, %7)::Vector{Float64}
│         (prob = (%1)(%3, %5, %6, %8))
│   %10 = Main.Tsit5()::Core.Const(Tsit5(; stage_limiter! = trivial_limiter!, step_limiter! = trivial_limiter!, thread = static(false),))
│   %11 = (:abstol, :reltol, :saveat)::Core.Const((:abstol, :reltol, :saveat))
│   %12 = Core.apply_type(Core.NamedTuple, %11)::Core.Const(NamedTuple{(:abstol, :reltol, :saveat)})
│   %13 = Core.tuple(1.0e-12, 1.0e-12, 0.1)::Core.Const((1.0e-12, 1.0e-12, 0.1))
│   %14 = (%12)(%13)::Core.Const((abstol = 1.0e-12, reltol = 1.0e-12, saveat = 0.1))
│   %15 = Core.kwcall(%14, Main.solve, prob, %10)::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing}
│   %16 = Main.sum(%15)::Float64
└──       return %16

and here's Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) ≈ dup with Enzyme Master

 Warning: TypeAnalysisDepthLimit
│ {[]:Pointer, [0]:Pointer, [0,0]:Pointer, [0,0,0]:Pointer, [0,0,0,0]:Pointer, [0,0,0,0,0]:Pointer, [0,0,0,0,0,-1]:Float@double, [0,0,0,0,8]:Integer, [0,0,0,0,9]:Integer, [0,0,0,0,10]:Integer, [0,0,0,0,11]:Integer, [0,0,0,0,12]:Integer, [0,0,0,0,13]:Integer, [0,0,0,0,14]:Integer, [0,0,0,0,15]:Integer, [0,0,0,0,16]:Integer, [0,0,0,0,17]:Integer, [0,0,0,0,18]:Integer, [0,0,0,0,19]:Integer, [0,0,0,0,20]:Integer, [0,0,0,0,21]:Integer, [0,0,0,0,22]:Integer, [0,0,0,0,23]:Integer, [0,0,0,0,24]:Integer, [0,0,0,0,25]:Integer, [0,0,0,0,26]:Integer, [0,0,0,0,27]:Integer, [0,0,0,0,28]:Integer, [0,0,0,0,29]:Integer, [0,0,0,0,30]:Integer, [0,0,0,0,31]:Integer, [0,0,0,0,32]:Integer, [0,0,0,0,33]:Integer, [0,0,0,0,34]:Integer, [0,0,0,0,35]:Integer, [0,0,0,0,36]:Integer, [0,0,0,0,37]:Integer, [0,0,0,0,38]:Integer, [0,0,0,0,39]:Integer, [0,0,8]:Integer, [0,0,9]:Integer, [0,0,10]:Integer, [0,0,11]:Integer, [0,0,12]:Integer, [0,0,13]:Integer, [0,0,14]:Integer, [0,0,15]:Integer, [0,0,16]:Integer, [0,0,17]:Integer, [0,0,18]:Integer, [0,0,19]:Integer, [0,0,20]:Integer, [0,0,21]:Integer, [0,0,22]:Integer, [0,0,23]:Integer, [0,0,24]:Integer, [0,0,25]:Integer, [0,0,26]:Integer, [0,0,27]:Integer, [0,0,28]:Integer, [0,0,29]:Integer, [0,0,30]:Integer, [0,0,31]:Integer, [0,0,32]:Integer, [0,0,33]:Integer, [0,0,34]:Integer, [0,0,35]:Integer, [0,0,36]:Integer, [0,0,37]:Integer, [0,0,38]:Integer, [0,0,39]:Integer, [8]:Integer, [9]:Integer, [10]:Integer, [11]:Integer, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Integer, [25]:Integer, [26]:Integer, [27]:Integer, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer, [32]:Integer, [33]:Integer, [34]:Integer, [35]:Integer, [36]:Integer, [37]:Integer, [38]:Integer, [39]:Integer}
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
┌ Warning: TypeAnalysisDepthLimit
│ {[]:Pointer, [0]:Pointer, [0,0]:Pointer, [0,0,0]:Pointer, [0,0,0,0]:Pointer, [0,0,0,0,0]:Pointer, [0,0,0,0,0,-1]:Float@double, [0,0,0,0,8]:Integer, [0,0,0,0,9]:Integer, [0,0,0,0,10]:Integer, [0,0,0,0,11]:Integer, [0,0,0,0,12]:Integer, [0,0,0,0,13]:Integer, [0,0,0,0,14]:Integer, [0,0,0,0,15]:Integer, [0,0,0,0,16]:Integer, [0,0,0,0,17]:Integer, [0,0,0,0,18]:Integer, [0,0,0,0,19]:Integer, [0,0,0,0,20]:Integer, [0,0,0,0,21]:Integer, [0,0,0,0,22]:Integer, [0,0,0,0,23]:Integer, [0,0,0,0,24]:Integer, [0,0,0,0,25]:Integer, [0,0,0,0,26]:Integer, [0,0,0,0,27]:Integer, [0,0,0,0,28]:Integer, [0,0,0,0,29]:Integer, [0,0,0,0,30]:Integer, [0,0,0,0,31]:Integer, [0,0,0,0,32]:Integer, [0,0,0,0,33]:Integer, [0,0,0,0,34]:Integer, [0,0,0,0,35]:Integer, [0,0,0,0,36]:Integer, [0,0,0,0,37]:Integer, [0,0,0,0,38]:Integer, [0,0,0,0,39]:Integer, [0,0,8]:Integer, [0,0,9]:Integer, [0,0,10]:Integer, [0,0,11]:Integer, [0,0,12]:Integer, [0,0,13]:Integer, [0,0,14]:Integer, [0,0,15]:Integer, [0,0,16]:Integer, [0,0,17]:Integer, [0,0,18]:Integer, [0,0,19]:Integer, [0,0,20]:Integer, [0,0,21]:Integer, [0,0,22]:Integer, [0,0,23]:Integer, [0,0,24]:Integer, [0,0,25]:Integer, [0,0,26]:Integer, [0,0,27]:Integer, [0,0,28]:Integer, [0,0,29]:Integer, [0,0,30]:Integer, [0,0,31]:Integer, [0,0,32]:Integer, [0,0,33]:Integer, [0,0,34]:Integer, [0,0,35]:Integer, [0,0,36]:Integer, [0,0,37]:Integer, [0,0,38]:Integer, [0,0,39]:Integer, [8]:Integer, [9]:Integer, [10]:Integer, [11]:Integer, [12]:Integer, [13]:Integer, [14]:Integer, [15]:Integer, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer, [24]:Integer, [25]:Integer, [26]:Integer, [27]:Integer, [28]:Integer, [29]:Integer, [30]:Integer, [31]:Integer, [32]:Integer, [33]:Integer, [34]:Integer, [35]:Integer, [36]:Integer, [37]:Integer, [38]:Integer, [39]:Integer}
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
┌ Warning: TypeAnalysisDepthLimit
│ {[]:Pointer, [0]:Pointer, [0,0]:Pointer, [0,0,0]:Pointer, [0,0,0,0]:Pointer, [0,0,0,0,-1]:Float@double, [0,0,0,8]:Integer, [0,0,0,9]:Integer, [0,0,0,10]:Integer, [0,0,0,11]:Integer, [0,0,0,12]:Integer, [0,0,0,13]:Integer, [0,0,0,14]:Integer, [0,0,0,15]:Integer, [0,0,0,16]:Integer, [0,0,0,17]:Integer, [0,0,0,18]:Integer, [0,0,0,19]:Integer, [0,0,0,20]:Integer, [0,0,0,21]:Integer, [0,0,0,22]:Integer, [0,0,0,23]:Integer, [0,0,0,24]:Integer, [0,0,0,25]:Integer, [0,0,0,26]:Integer, [0,0,0,27]:Integer, [0,0,0,28]:Integer, [0,0,0,29]:Integer, [0,0,0,30]:Integer, [0,0,0,31]:Integer, [0,0,0,32]:Integer, [0,0,0,33]:Integer, [0,0,0,34]:Integer, [0,0,0,35]:Integer, [0,0,0,36]:Integer, [0,0,0,37]:Integer, [0,0,0,38]:Integer, [0,0,0,39]:Integer, [0,8]:Integer, [0,9]:Integer, [0,10]:Integer, [0,11]:Integer, [0,12]:Integer, [0,13]:Integer, [0,14]:Integer, [0,15]:Integer, [0,16]:Integer, [0,17]:Integer, [0,18]:Integer, [0,19]:Integer, [0,20]:Integer, [0,21]:Integer, [0,22]:Integer, [0,23]:Integer, [0,24]:Integer, [0,25]:Integer, [0,26]:Integer, [0,27]:Integer, [0,28]:Integer, [0,29]:Integer, [0,30]:Integer, [0,31]:Integer, [0,32]:Integer, [0,33]:Integer, [0,34]:Integer, [0,35]:Integer, [0,36]:Integer, [0,37]:Integer, [0,38]:Integer, [0,39]:Integer, [8]:Pointer, [8,0]:Pointer, [8,0,-1]:Float@double, [8,8]:Integer, [8,9]:Integer, [8,10]:Integer, [8,11]:Integer, [8,12]:Integer, [8,13]:Integer, [8,14]:Integer, [8,15]:Integer, [8,16]:Integer, [8,17]:Integer, [8,18]:Integer, [8,19]:Integer, [8,20]:Integer, [8,21]:Integer, [8,22]:Integer, [8,23]:Integer, [8,24]:Integer, [8,25]:Integer, [8,26]:Integer, [8,27]:Integer, [8,28]:Integer, [8,29]:Integer, [8,30]:Integer, [8,31]:Integer, [8,32]:Integer, [8,33]:Integer, [8,34]:Integer, [8,35]:Integer, [8,36]:Integer, [8,37]:Integer, [8,38]:Integer, [8,39]:Integer, [16]:Pointer, [16,0]:Pointer, [16,0,0]:Pointer, [16,0,0,0]:Pointer, [16,0,0,0,0]:Pointer, [16,0,0,0,0,0]:Pointer, [16,0,0,0,0,8]:Integer, [16,0,0,0,0,9]:Integer, [16,0,0,0,0,10]:Integer, [16,0,0,0,0,11]:Integer, [16,0,0,0,0,12]:Integer, [16,0,0,0,0,13]:Integer, [16,0,0,0,0,14]:Integer, [16,0,0,0,0,15]:Integer, [16,0,0,0,0,16]:Integer, [16,0,0,0,0,17]:Integer, [16,0,0,0,0,18]:Integer, [16,0,0,0,0,19]:Integer, [16,0,0,0,0,20]:Integer, [16,0,0,0,0,21]:Integer, [16,0,0,0,0,22]:Integer, [16,0,0,0,0,23]:Integer, [16,0,0,0,0,24]:Integer, [16,0,0,0,0,25]:Integer, [16,0,0,0,0,26]:Integer, [16,0,0,0,0,27]:Integer, [16,0,0,0,0,28]:Integer, [16,0,0,0,0,29]:Integer, [16,0,0,0,0,30]:Integer, [16,0,0,0,0,31]:Integer, [16,0,0,0,0,32]:Integer, [16,0,0,0,0,33]:Integer, [16,0,0,0,0,34]:Integer, [16,0,0,0,0,35]:Integer, [16,0,0,0,0,36]:Integer, [16,0,0,0,0,37]:Integer, [16,0,0,0,0,38]:Integer, [16,0,0,0,0,39]:Integer, [16,0,0,8]:Integer, [16,0,0,9]:Integer, [16,0,0,10]:Integer, [16,0,0,11]:Integer, [16,0,0,12]:Integer, [16,0,0,13]:Integer, [16,0,0,14]:Integer, [16,0,0,15]:Integer, [16,0,0,16]:Integer, [16,0,0,17]:Integer, [16,0,0,18]:Integer, [16,0,0,19]:Integer, [16,0,0,20]:Integer, [16,0,0,21]:Integer, [16,0,0,22]:Integer, [16,0,0,23]:Integer, [16,0,0,24]:Integer, [16,0,0,25]:Integer, [16,0,0,26]:Integer, [16,0,0,27]:Integer, [16,0,0,28]:Integer, [16,0,0,29]:Integer, [16,0,0,30]:Integer, [16,0,0,31]:Integer, [16,0,0,32]:Integer, [16,0,0,33]:Integer, [16,0,0,34]:Integer, [16,0,0,35]:Integer, [16,0,0,36]:Integer, [16,0,0,37]:Integer, [16,0,0,38]:Integer, [16,0,0,39]:Integer, [16,8]:Integer, [16,9]:Integer, [16,10]:Integer, [16,11]:Integer, [16,12]:Integer, [16,13]:Integer, [16,14]:Integer, [16,15]:Integer, [16,16]:Integer, [16,17]:Integer, [16,18]:Integer, [16,19]:Integer, [16,20]:Integer, [16,21]:Integer, [16,22]:Integer, [16,23]:Integer, [16,24]:Integer, [16,25]:Integer, [16,26]:Integer, [16,27]:Integer, [16,28]:Integer, [16,29]:Integer, [16,30]:Integer, [16,31]:Integer, [16,32]:Integer, [16,33]:Integer, [16,34]:Integer, [16,35]:Integer, [16,36]:Integer, [16,37]:Integer, [16,38]:Integer, [16,39]:Integer, [24]:Pointer, [24,0]:Integer, [24,8]:Pointer, [24,8,0]:Pointer, [24,8,0,-1]:Float@double, [24,8,8]:Integer, [24,8,9]:Integer, [24,8,10]:Integer, [24,8,11]:Integer, [24,8,12]:Integer, [24,8,13]:Integer, [24,8,14]:Integer, [24,8,15]:Integer, [24,8,16]:Integer, [24,8,17]:Integer, [24,8,18]:Integer, [24,8,19]:Integer, [24,8,20]:Integer, [24,8,21]:Integer, [24,8,22]:Integer, [24,8,23]:Integer, [24,8,24]:Integer, [24,8,25]:Integer, [24,8,26]:Integer, [24,8,27]:Integer, [24,8,28]:Integer, [24,8,29]:Integer, [24,8,30]:Integer, [24,8,31]:Integer, [24,8,32]:Integer, [24,8,33]:Integer, [24,8,34]:Integer, [24,8,35]:Integer, [24,8,36]:Integer, [24,8,37]:Integer, [24,8,38]:Integer, [24,8,39]:Integer, [24,16]:Float@double, [24,24]:Float@double, [24,32]:Pointer, [24,32,0]:Pointer, [24,32,0,-1]:Float@double, [24,32,8]:Integer, [24,32,9]:Integer, [24,32,10]:Integer, [24,32,11]:Integer, [24,32,12]:Integer, [24,32,13]:Integer, [24,32,14]:Integer, [24,32,15]:Integer, [24,32,16]:Integer, [24,32,17]:Integer, [24,32,18]:Integer, [24,32,19]:Integer, [24,32,20]:Integer, [24,32,21]:Integer, [24,32,22]:Integer, [24,32,23]:Integer, [24,32,24]:Integer, [24,32,25]:Integer, [24,32,26]:Integer, [24,32,27]:Integer, [24,32,28]:Integer, [24,32,29]:Integer, [24,32,30]:Integer, [24,32,31]:Integer, [24,32,32]:Integer, [24,32,33]:Integer, [24,32,34]:Integer, [24,32,35]:Integer, [24,32,36]:Integer, [24,32,37]:Integer, [24,32,38]:Integer, [24,32,39]:Integer, [32]:Integer, [40]:Pointer, [40,0]:Pointer, [40,0,0]:Pointer, [40,0,0,0]:Pointer, [40,0,0,0,-1]:Float@double, [40,0,0,8]:Integer, [40,0,0,9]:Integer, [40,0,0,10]:Integer, [40,0,0,11]:Integer, [40,0,0,12]:Integer, [40,0,0,13]:Integer, [40,0,0,14]:Integer, [40,0,0,15]:Integer, [40,0,0,16]:Integer, [40,0,0,17]:Integer, [40,0,0,18]:Integer, [40,0,0,19]:Integer, [40,0,0,20]:Integer, [40,0,0,21]:Integer, [40,0,0,22]:Integer, [40,0,0,23]:Integer, [40,0,0,24]:Integer, [40,0,0,25]:Integer, [40,0,0,26]:Integer, [40,0,0,27]:Integer, [40,0,0,28]:Integer, [40,0,0,29]:Integer, [40,0,0,30]:Integer, [40,0,0,31]:Integer, [40,0,0,32]:Integer, [40,0,0,33]:Integer, [40,0,0,34]:Integer, [40,0,0,35]:Integer, [40,0,0,36]:Integer, [40,0,0,37]:Integer, [40,0,0,38]:Integer, [40,0,0,39]:Integer, [40,8]:Integer, [40,9]:Integer, [40,10]:Integer, [40,11]:Integer, [40,12]:Integer, [40,13]:Integer, [40,14]:Integer, [40,15]:Integer, [40,16]:Integer, [40,17]:Integer, [40,18]:Integer, [40,19]:Integer, [40,20]:Integer, [40,21]:Integer, [40,22]:Integer, [40,23]:Integer, [40,24]:Integer, [40,25]:Integer, [40,26]:Integer, [40,27]:Integer, [40,28]:Integer, [40,29]:Integer, [40,30]:Integer, [40,31]:Integer, [40,32]:Integer, [40,33]:Integer, [40,34]:Integer, [40,35]:Integer, [40,36]:Integer, [40,37]:Integer, [40,38]:Integer, [40,39]:Integer, [48]:Pointer, [48,0]:Pointer, [48,0,-1]:Float@double, [48,8]:Integer, [48,9]:Integer, [48,10]:Integer, [48,11]:Integer, [48,12]:Integer, [48,13]:Integer, [48,14]:Integer, [48,15]:Integer, [48,16]:Integer, [48,17]:Integer, [48,18]:Integer, [48,19]:Integer, [48,20]:Integer, [48,21]:Integer, [48,22]:Integer, [48,23]:Integer, [48,24]:Integer, [48,25]:Integer, [48,26]:Integer, [48,27]:Integer, [48,28]:Integer, [48,29]:Integer, [48,30]:Integer, [48,31]:Integer, [48,32]:Integer, [48,33]:Integer, [48,34]:Integer, [48,35]:Integer, [48,36]:Integer, [48,37]:Integer, [48,38]:Integer, [48,39]:Integer, [56]:Pointer, [56,0]:Pointer, [56,0,0]:Pointer, [56,0,0,0]:Pointer, [56,0,0,0,0]:Pointer, [56,0,0,0,0,0]:Pointer, [56,0,0,0,0,8]:Integer, [56,0,0,0,0,9]:Integer, [56,0,0,0,0,10]:Integer, [56,0,0,0,0,11]:Integer, [56,0,0,0,0,12]:Integer, [56,0,0,0,0,13]:Integer, [56,0,0,0,0,14]:Integer, [56,0,0,0,0,15]:Integer, [56,0,0,0,0,16]:Integer, [56,0,0,0,0,17]:Integer, [56,0,0,0,0,18]:Integer, [56,0,0,0,0,19]:Integer, [56,0,0,0,0,20]:Integer, [56,0,0,0,0,21]:Integer, [56,0,0,0,0,22]:Integer, [56,0,0,0,0,23]:Integer, [56,0,0,0,0,24]:Integer, [56,0,0,0,0,25]:Integer, [56,0,0,0,0,26]:Integer, [56,0,0,0,0,27]:Integer, [56,0,0,0,0,28]:Integer, [56,0,0,0,0,29]:Integer, [56,0,0,0,0,30]:Integer, [56,0,0,0,0,31]:Integer, [56,0,0,0,0,32]:Integer, [56,0,0,0,0,33]:Integer, [56,0,0,0,0,34]:Integer, [56,0,0,0,0,35]:Integer, [56,0,0,0,0,36]:Integer, [56,0,0,0,0,37]:Integer, [56,0,0,0,0,38]:Integer, [56,0,0,0,0,39]:Integer, [56,0,0,8]:Integer, [56,0,0,9]:Integer, [56,0,0,10]:Integer, [56,0,0,11]:Integer, [56,0,0,12]:Integer, [56,0,0,13]:Integer, [56,0,0,14]:Integer, [56,0,0,15]:Integer, [56,0,0,16]:Integer, [56,0,0,17]:Integer, [56,0,0,18]:Integer, [56,0,0,19]:Integer, [56,0,0,20]:Integer, [56,0,0,21]:Integer, [56,0,0,22]:Integer, [56,0,0,23]:Integer, [56,0,0,24]:Integer, [56,0,0,25]:Integer, [56,0,0,26]:Integer, [56,0,0,27]:Integer, [56,0,0,28]:Integer, [56,0,0,29]:Integer, [56,0,0,30]:Integer, [56,0,0,31]:Integer, [56,0,0,32]:Integer, [56,0,0,33]:Integer, [56,0,0,34]:Integer, [56,0,0,35]:Integer, [56,0,0,36]:Integer, [56,0,0,37]:Integer, [56,0,0,38]:Integer, [56,0,0,39]:Integer, [56,8]:Integer, [56,9]:Integer, [56,10]:Integer, [56,11]:Integer, [56,12]:Integer, [56,13]:Integer, [56,14]:Integer, [56,15]:Integer, [56,16]:Integer, [56,17]:Integer, [56,18]:Integer, [56,19]:Integer, [56,20]:Integer, [56,21]:Integer, [56,22]:Integer, [56,23]:Integer, [56,24]:Integer, [56,25]:Integer, [56,26]:Integer, [56,27]:Integer, [56,28]:Integer, [56,29]:Integer, [56,30]:Integer, [56,31]:Integer, [56,32]:Integer, [56,33]:Integer, [56,34]:Integer, [56,35]:Integer, [56,36]:Integer, [56,37]:Integer, [56,38]:Integer, [56,39]:Integer, [64]:Integer, [72]:Integer, [80]:Integer, [81]:Integer, [82]:Integer, [83]:Integer, [84]:Integer, [85]:Integer, [86]:Integer, [87]:Integer, [88]:Pointer, [88,0]:Integer, [88,1]:Integer, [88,2]:Integer, [88,3]:Integer, [88,4]:Integer, [88,5]:Integer, [88,6]:Integer, [88,7]:Integer, [88,8]:Integer, [88,9]:Integer, [88,10]:Integer, [88,11]:Integer, [88,12]:Integer, [88,13]:Integer, [88,14]:Integer, [88,15]:Integer, [88,16]:Integer, [88,17]:Integer, [88,18]:Integer, [88,19]:Integer, [88,20]:Integer, [88,21]:Integer, [88,22]:Integer, [88,23]:Integer, [88,24]:Integer, [88,25]:Integer, [88,26]:Integer, [88,27]:Integer, [88,28]:Integer, [88,29]:Integer, [88,30]:Integer, [88,31]:Integer, [88,32]:Integer, [88,33]:Integer, [88,34]:Integer, [88,35]:Integer, [88,36]:Integer, [88,37]:Integer, [88,38]:Integer, [88,39]:Integer, [88,40]:Integer, [88,41]:Integer, [88,42]:Integer, [88,43]:Integer, [88,44]:Integer, [88,45]:Integer, [88,46]:Integer, [88,47]:Integer, [88,48]:Integer, [88,49]:Integer, [88,50]:Integer, [88,51]:Integer, [88,52]:Integer, [88,53]:Integer, [88,54]:Integer, [88,55]:Integer, [88,56]:Integer, [88,57]:Integer, [88,58]:Integer, [88,59]:Integer, [88,60]:Integer, [88,61]:Integer, [88,62]:Integer, [88,63]:Integer, [88,64]:Integer, [88,65]:Integer, [88,66]:Integer, [88,67]:Integer, [88,68]:Integer, [88,69]:Integer, [88,70]:Integer, [88,71]:Integer, [88,72]:Integer, [88,73]:Integer, [88,74]:Integer, [88,75]:Integer, [88,76]:Integer, [88,77]:Integer, [88,78]:Integer, [88,79]:Integer, [88,80]:Float@double}
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
warning: didn't implement memmove, using memcpy as fallback which can result in errors
warning: didn't implement memmove, using memcpy as fallback which can result in errors
warning: didn't implement memmove, using memcpy as fallback which can result in errors
ERROR: Enzyme execution failed.
Enzyme: No custom reverse rule was appliable for Tuple{@NamedTuple{abstol::Float64, reltol::Float64, saveat::Float64}, typeof(EnzymeCore.EnzymeRules.reverse), EnzymeCore.EnzymeRules.ConfigWidth{1, true, true, (false, true, false, true, true, false)}, Const{typeof(DiffEqBase.solve_up)}, Type{Duplicated{ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing}}}, Any, Duplicated{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}}, Const{Nothing}, Duplicated{Vector{Float64}}, Duplicated{Vector{Float64}}, Const{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}}
Stacktrace:
 [1] #solve#40
   @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977
 [2] solve
   @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967
 [3] senseloss
 [4] senseloss

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1342
  [2] augmented_primal
    @ ~/.julia/packages/DiffEqBase/nKsvb/ext/DiffEqBaseEnzymeExt.jl:9 [inlined]
  [3] #solve#40
    @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977 [inlined]
  [4] solve
    @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967 [inlined]
  [5] senseloss
    @ ~/Enzyme905/MWE.jl:15 [inlined]
  [6] senseloss
    @ ~/Enzyme905/MWE.jl:0 [inlined]
  [7] diffejulia_senseloss_7763_inner_1wrap
    @ ~/Enzyme905/MWE.jl:0
  [8] macro expansion
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5301 [inlined]
  [9] enzyme_call
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:4979 [inlined]
 [10] CombinedAdjointThunk
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:4944 [inlined]
 [11] autodiff(::ReverseMode{false, FFIABI}, f::Const{senseloss{…}}, ::Type{Active}, args::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:215
 [12] autodiff(::ReverseMode{false, FFIABI}, ::senseloss{InterpolatingAdjoint{…}}, ::Type, ::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:224
 [13] top-level scope
Some type information was truncated. Use `show(err)` to see complete types.

Interestingly, when I ran it again, it slightly changed. I get

ERROR: Enzyme execution failed.
Enzyme: No custom reverse rule was appliable for Tuple{@NamedTuple{abstol::Float64, reltol::Float64, saveat::Float64}, typeof(EnzymeCore.EnzymeRules.reverse), EnzymeCore.EnzymeRules.ConfigWidth{1, true, true, (false, true, false, true, true, false)}, Const{typeof(DiffEqBase.solve_up)}, Type{Duplicated{ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing}}}, Any, Duplicated{ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#5#6", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}}, Const{Nothing}, Duplicated{Vector{Float64}}, Duplicated{Vector{Float64}}, Const{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}}
Stacktrace:
 [1] #solve#40
   @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977
 [2] solve
   @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967
 [3] senseloss
   @ ~/Enzyme905/MWE.jl:15
 [4] senseloss
   @ ~/Enzyme905/MWE.jl:0

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1342
  [2] augmented_primal
    @ ~/.julia/packages/DiffEqBase/nKsvb/ext/DiffEqBaseEnzymeExt.jl:9 [inlined]
  [3] #solve#40
    @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:977 [inlined]
  [4] solve
    @ ~/.julia/packages/DiffEqBase/nKsvb/src/solve.jl:967 [inlined]
  [5] senseloss
    @ ~/Enzyme905/MWE.jl:15 [inlined]
  [6] senseloss
    @ ~/Enzyme905/MWE.jl:0 [inlined]
  [7] diffejulia_senseloss_7763_inner_1wrap
    @ ~/Enzyme905/MWE.jl:0
  [8] macro expansion
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5301 [inlined]
  [9] enzyme_call
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:4979 [inlined]
 [10] CombinedAdjointThunk
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:4944 [inlined]
 [11] autodiff(::ReverseMode{false, FFIABI}, f::Const{senseloss{…}}, ::Type{Active}, args::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:215
 [12] autodiff(::ReverseMode{false, FFIABI}, ::senseloss{InterpolatingAdjoint{…}}, ::Type, ::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:224
 [13] top-level scope
Some type information was truncated. Use `show(err)` to see complete types.

No segfault!

@wsmoses
Copy link

wsmoses commented Dec 19, 2023

That error message basically says "you implemented a custom rule for this function, but not for these activity states". So this presumably is an issue in SciML to complete the rule.

@wsmoses
Copy link

wsmoses commented Dec 19, 2023

I think the error is on this line:

https://github.com/SciML/DiffEqBase.jl/blob/37814405e9b8ed5f31f099cdd3c00ada02c0a25d/ext/DiffEqBaseEnzymeExt.jl#L33C136-L33C136

It should probably be "Type{Duplicated{RT}}" rahter than "Type{<:Duplicated{RT}}"

@ChrisRackauckas
Copy link
Member Author

Locally I made that change and it didn't make a difference. Type{<:Duplicated{RT}} captures strictly more cases than Type{Duplicated{RT}} so I'm not surprised, I would've been surprised if that did work 😅

Can we get an error message that says which activity is not handled? I'm pretty confused as to what it might be calling.

@wsmoses
Copy link

wsmoses commented Dec 27, 2023

I think you need to release SciML/DiffEqBase.jl#982 for this to pass. Not sure, but I'm not confident that the test would use your manifest here.

Project.toml Outdated Show resolved Hide resolved
@ChrisRackauckas
Copy link
Member Author

The manifest does work, but because it's harmless I merged and tagged and so now this is setup without a manifest. Same issue.

@ChrisRackauckas
Copy link
Member Author

There's a tracker issue that just came up in here that we can ignore for now, but the Enzyme one is more intriguing.

https://github.com/SciML/SciMLSensitivity.jl/actions/runs/7385603005/job/20090674633?pr=905#step:6:814

TypeError: in typeassert, expected SciMLBase.ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.var"##Alternative AD Frontend#335".odef), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.var"##Alternative AD Frontend#335".odef), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing}, got a value of type SciMLBase.ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Nothing, SciMLBase.ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.var"##Alternative AD Frontend#335".odef), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, SciMLBase.SensitivityInterpolation{StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Vector{Float64}}}, SciMLBase.DEStats, Nothing}

This is by design. In the ODE solution object, sol(t) does an interpolation using internal solver values. If one uses the adjoint method then you cannot differentiate with respect to these values (it actually becomes equivalent to having to differentiate the solver), and so when someone is using the adjoint method we disable this by replacing the interpolation object from the solution before passing it on when in reverse mode AD https://github.com/SciML/SciMLBase.jl/blob/350e3d81db91d8f901a9c5d8e8977adb3a9273d5/src/solutions/ode_solutions.jl#L388-L404 so that way it throws a customized error on interpolations which violate what is differentiable.

Is there a way to make Enzyme understand that?

@wsmoses
Copy link

wsmoses commented Jan 2, 2024

No idea where that's coming from but Enzyme has an invariant that the primal and shadow are the same type/shape.

@ChrisRackauckas
Copy link
Member Author

No idea where that's coming from but Enzyme has an invariant that the primal and shadow are the same type/shape.

But that cannot or should not be true in this case for correctness checking. How can we get around that?

@ChrisRackauckas
Copy link
Member Author

I set it up with runtime checks so that the type invariant can hold, but I get an error message with no information deep in the compiler? 😅

Alternative AD Frontend: Error During Test at /home/runner/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:30
  Got exception outside of a @test
  LoadError: AssertionError: value_type(lhs_v) == value_type(rhs_v)

@wsmoses
Copy link

wsmoses commented Jan 7, 2024

@ChrisRackauckas I just merged EnzymeAD/Enzyme.jl#1237 to main which should provide some more information, can you rerun with that. In essence something is going wrong when analyzing the GC behavior of the custom rule code being injected and we're throwing up our hands in advance rather than risk a segfault.

@ArnoStrouwen
Copy link
Member

Running the following example, which uses a lower level interface than the above examples:

using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme
struct senseloss{T}
    sense::T
end
function (f::senseloss)(u0p)
	prob = ODEProblem{false}((u, p, t) -> u .* p, u0p[1:1], (0.0, 1.0), u0p[2:2])
    integrator = init(prob, Tsit5())
    integrator.u[1]
end
u0p = [2.0, 3.0]
du0p = zeros(2)
#dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)
Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p))
Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) # need to run twice to get below output

gives

julia> Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) # need to run twice to get below output
ERROR: Enzyme execution failed.
Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define double @preprocess_julia_senseloss_3007_inner.1([1 x { i8, i8 }] %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) local_unnamed_addr #165 !dbg !8783 {
entry:
  %newstruct.i = alloca [1 x [2 x i64]], align 8
  %newstruct2.i = alloca [1 x [2 x i64]], align 8
  %2 = alloca [2 x [2 x {} addrspace(10)*]], align 8
  %3 = bitcast [1 x [2 x i64]]* %newstruct.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 noundef 16, i8* noundef nonnull dereferenceable(16) %3) #183
  %4 = bitcast [1 x [2 x i64]]* %newstruct2.i to i8*
  call void @llvm.lifetime.start.p0i8(i64 noundef 16, i8* noundef nonnull dereferenceable(16) %4) #183
  %5 = bitcast [2 x [2 x {} addrspace(10)*]]* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 noundef 32, i8* noundef nonnull dereferenceable(32) %5) #183
  %6 = call {}*** @julia.get_pgcstack() #183
  %current_task1.i16 = getelementptr inbounds {}**, {}*** %6, i64 -14
  %current_task1.i = bitcast {}*** %current_task1.i16 to {}**
  %ptls_field.i17 = getelementptr inbounds {}**, {}*** %6, i64 2
  %7 = bitcast {}*** %ptls_field.i17 to i64***
  %ptls_load.i1819 = load i64**, i64*** %7, align 8, !tbaa !159
  %8 = getelementptr inbounds i64*, i64** %ptls_load.i1819, i64 2
  %safepoint.i = load i64*, i64** %8, align 8, !tbaa !163, !invariant.load !158
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #183, !dbg !8784
  fence syncscope("singlethread") seq_cst
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull writeonly align 8 dereferenceable(16) %3, i8* noundef nonnull readonly align 8 dereferenceable(16) bitcast ([2 x i64]* @_j_const1 to i8*), i64 noundef 16, i1 noundef false) #183, !dbg !8786, !tbaa !301, !alias.scope !4349, !noalias !8789
  %9 = addrspacecast {} addrspace(10)* %1 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !8792
  %arraylen_ptr.i = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %9, i64 0, i32 1, !dbg !8792
  %arraylen.i = load i64, i64 addrspace(11)* %arraylen_ptr.i, align 8, !dbg !8792, !tbaa !204, !range !207, !alias.scope !208, !noalias !209
  %.not = icmp eq i64 %arraylen.i, 0, !dbg !8798
  br i1 %.not, label %L19.i, label %L18.i, !dbg !8801

L18.i:                                            ; preds = %entry
  %10 = call noalias nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139751658212592 to {}*) to {} addrspace(10)*), i64 noundef 1) #184, !dbg !8802
  call fastcc void @julia__copyto_impl__3570({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %10, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1, i64 noundef signext 1, i64 noundef signext 1) #183, !dbg !8807
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 8 dereferenceable(16) %4, i8* noundef nonnull align 8 dereferenceable(16) bitcast ([2 x i64]* @_j_const2 to i8*), i64 noundef 16, i1 noundef false) #183, !dbg !8786, !tbaa !301, !alias.scope !4349, !noalias !8789
  %arraylen4.i = load i64, i64 addrspace(11)* %arraylen_ptr.i, align 8, !dbg !8792, !tbaa !204, !range !207, !alias.scope !208, !noalias !209
  %11 = icmp ult i64 %arraylen4.i, 2, !dbg !8798
  br i1 %11, label %L51.i, label %L50.i, !dbg !8801

L19.i:                                            ; preds = %entry
  %12 = addrspacecast [1 x [2 x i64]]* %newstruct.i to [1 x [2 x i64]] addrspace(11)*, !dbg !8801
  call fastcc void @julia_throw_boundserror_3015({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1, [1 x [2 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %12) #185, !dbg !8801
  unreachable, !dbg !8801

L50.i:                                            ; preds = %L18.i
  %13 = call noalias nonnull {} addrspace(10)* @ijl_alloc_array_1d({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139751658212592 to {}*) to {} addrspace(10)*), i64 noundef 1) #184, !dbg !8802
  call fastcc void @julia__copyto_impl__3570({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %13, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1, i64 noundef signext 2, i64 noundef signext 1) #183, !dbg !8807
  call fastcc void @julia____14_3594() #183, !dbg !8809
  %newstruct7.i = call noalias nonnull dereferenceable(40) {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 40, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 139749124209808 to {}*) to {} addrspace(10)*)) #186, !dbg !8812
  %14 = addrspacecast {} addrspace(10)* %newstruct7.i to {} addrspace(10)* addrspace(11)*, !dbg !8812
  %15 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 1, !dbg !8812
  store {} addrspace(10)* null, {} addrspace(10)* addrspace(11)* %15, align 8, !dbg !8812, !tbaa !1751, !alias.scope !255, !noalias !8818
  %16 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 4, !dbg !8812
  store {} addrspace(10)* null, {} addrspace(10)* addrspace(11)* %16, align 8, !dbg !8812, !tbaa !1751, !alias.scope !255, !noalias !8818
  %17 = bitcast {} addrspace(10)* %newstruct7.i to i8 addrspace(10)*, !dbg !8812
  store i8 1, i8 addrspace(10)* %17, align 8, !dbg !8812
  %18 = addrspacecast {} addrspace(10)* %newstruct7.i to i8 addrspace(11)*, !dbg !8812
  %19 = getelementptr inbounds i8, i8 addrspace(11)* %18, i64 8, !dbg !8812
  %20 = bitcast i8 addrspace(11)* %19 to {} addrspace(10)* addrspace(11)*, !dbg !8812
  store atomic {} addrspace(10)* %10, {} addrspace(10)* addrspace(11)* %20 release, align 8, !dbg !8812, !tbaa !1751, !alias.scope !255, !noalias !8818
  %21 = getelementptr inbounds i8, i8 addrspace(11)* %18, i64 16, !dbg !8812
  call void @llvm.memcpy.p11i8.p0i8.i64(i8 addrspace(11)* noundef align 8 dereferenceable(16) dereferenceable_or_null(24) %21, i8* noundef nonnull align 8 dereferenceable(16) bitcast ([2 x double]* @_j_const5 to i8*), i64 noundef 16, i1 noundef false) #183, !dbg !8812, !tbaa !2607, !alias.scope !255, !noalias !8818
  %22 = getelementptr inbounds i8, i8 addrspace(11)* %18, i64 32, !dbg !8812
  %23 = bitcast i8 addrspace(11)* %22 to {} addrspace(10)* addrspace(11)*, !dbg !8812
  store atomic {} addrspace(10)* %13, {} addrspace(10)* addrspace(11)* %23 release, align 8, !dbg !8812, !tbaa !1751, !alias.scope !255, !noalias !8818
  %unbox.i.unpack = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 139749814316464 to {} addrspace(10)**), align 16, !dbg !8819, !tbaa !414, !alias.scope !255, !noalias !256
  %unbox.i.unpack20 = load {} addrspace(10)*, {} addrspace(10)** inttoptr (i64 139749814316472 to {} addrspace(10)**), align 8, !dbg !8819, !tbaa !414, !alias.scope !255, !noalias !256
  %.fca.0.0.gep = getelementptr inbounds [2 x [2 x {} addrspace(10)*]], [2 x [2 x {} addrspace(10)*]]* %2, i64 0, i64 0, i64 0, !dbg !8827
  store {} addrspace(10)* %10, {} addrspace(10)** %.fca.0.0.gep, align 8, !dbg !8827, !noalias !8828
  %.fca.0.1.gep = getelementptr inbounds [2 x [2 x {} addrspace(10)*]], [2 x [2 x {} addrspace(10)*]]* %2, i64 0, i64 0, i64 1, !dbg !8827
  store {} addrspace(10)* %13, {} addrspace(10)** %.fca.0.1.gep, align 8, !dbg !8827, !noalias !8828
  %.fca.1.0.gep = getelementptr inbounds [2 x [2 x {} addrspace(10)*]], [2 x [2 x {} addrspace(10)*]]* %2, i64 0, i64 1, i64 0, !dbg !8827
  store {} addrspace(10)* %unbox.i.unpack, {} addrspace(10)** %.fca.1.0.gep, align 8, !dbg !8827, !noalias !8828
  %.fca.1.1.gep = getelementptr inbounds [2 x [2 x {} addrspace(10)*]], [2 x [2 x {} addrspace(10)*]]* %2, i64 0, i64 1, i64 1, !dbg !8827
  store {} addrspace(10)* %unbox.i.unpack20, {} addrspace(10)** %.fca.1.1.gep, align 8, !dbg !8827, !noalias !8828
  %24 = addrspacecast [2 x [2 x {} addrspace(10)*]]* %2 to [2 x [2 x {} addrspace(10)*]] addrspace(11)*, !dbg !8827
  %25 = call fastcc nonnull {} addrspace(10)* @julia__get_concrete_problem_58_3589([2 x [2 x {} addrspace(10)*]] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(32) %24, {} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %newstruct7.i) #183, !dbg !8827
  %26 = addrspacecast {} addrspace(10)* %25 to i8 addrspace(11)*, !dbg !8829
  %aggregate_load_box.i.sroa.0.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(11)* %26, i64 16, !dbg !8829
  %aggregate_load_box.i.sroa.0.0..sroa_cast = bitcast i8 addrspace(11)* %aggregate_load_box.i.sroa.0.0..sroa_idx to double addrspace(11)*, !dbg !8829
  %aggregate_load_box.i.sroa.3.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(11)* %26, i64 24, !dbg !8829
  %aggregate_load_box.i.sroa.3.0..sroa_cast = bitcast i8 addrspace(11)* %aggregate_load_box.i.sroa.3.0..sroa_idx to double addrspace(11)*, !dbg !8829
  %aggregate_load_box.i.sroa.3.0.copyload = load double, double addrspace(11)* %aggregate_load_box.i.sroa.3.0..sroa_cast, align 1, !dbg !8829, !tbaa !301, !alias.scope !4349, !noalias !8789
  %aggregate_load_box10.i.sroa.0.0.copyload = load double, double addrspace(11)* %aggregate_load_box.i.sroa.0.0..sroa_cast, align 1, !dbg !8829, !tbaa !301, !alias.scope !4349, !noalias !8789
  %27 = fsub double %aggregate_load_box.i.sroa.3.0.copyload, %aggregate_load_box10.i.sroa.0.0.copyload, !dbg !8838
  %28 = call nonnull "enzyme_inactive" {}* @julia.pointer_from_objref({} addrspace(11)* noundef addrspacecast ({}* inttoptr (i64 139749438057424 to {}*) to {} addrspace(11)*)) #187, !dbg !8839
  %29 = bitcast {}* %28 to {} addrspace(10)**, !dbg !8839
  %30 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)** %29, i64 1, !dbg !8839
  %string_ptr.i = ptrtoint {} addrspace(10)** %30 to i64, !dbg !8839
  %31 = call nonnull {} addrspace(10)* @ijl_tagged_gensym(i64 %string_ptr.i, i64 noundef 14) #183, !dbg !8841
  %32 = call fastcc nonnull {} addrspace(10)* @julia____init_747_3017(double %27, {} addrspace(10)* nofree noundef nonnull %31, {} addrspace(10)* noundef nonnull align 8 dereferenceable(40) %25) #183, !dbg !8842
  %33 = addrspacecast {} addrspace(10)* %32 to i8 addrspace(11)*, !dbg !8843
  %getfield_addr13.i = getelementptr inbounds i8, i8 addrspace(11)* %33, i64 104, !dbg !8843
  %34 = bitcast i8 addrspace(11)* %getfield_addr13.i to {} addrspace(10)* addrspace(11)*, !dbg !8843
  %getfield14.i = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %34 unordered, align 8, !dbg !8843, !tbaa !1751, !alias.scope !255, !noalias !256, !nonnull !158, !dereferenceable !266, !align !267
  %35 = addrspacecast {} addrspace(10)* %getfield14.i to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !8845
  %arraylen_ptr15.i = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %35, i64 0, i32 1, !dbg !8845
  %arraylen16.i = load i64, i64 addrspace(11)* %arraylen_ptr15.i, align 8, !dbg !8845, !tbaa !204, !range !207, !alias.scope !208, !noalias !209
  %inbounds.i.not = icmp eq i64 %arraylen16.i, 0, !dbg !8845
  br i1 %inbounds.i.not, label %oob.i, label %julia_senseloss_3007_inner.exit, !dbg !8845

L51.i:                                            ; preds = %L18.i
  %36 = addrspacecast [1 x [2 x i64]]* %newstruct2.i to [1 x [2 x i64]] addrspace(11)*, !dbg !8801
  call fastcc void @julia_throw_boundserror_3015({} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1, [1 x [2 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %36) #185, !dbg !8801
  unreachable, !dbg !8801

oob.i:                                            ; preds = %L50.i
  %errorbox.i = alloca i64, align 8, !dbg !8845
  store i64 1, i64* %errorbox.i, align 8, !dbg !8845, !noalias !8828
  %37 = addrspacecast {} addrspace(10)* %getfield14.i to {} addrspace(12)*, !dbg !8845
  call void @ijl_bounds_error_ints({} addrspace(12)* %37, i64* nonnull align 8 %errorbox.i, i64 1) #183, !dbg !8845
  unreachable, !dbg !8845

julia_senseloss_3007_inner.exit:                  ; preds = %L50.i
  %38 = addrspacecast {} addrspace(10)* %getfield14.i to double addrspace(13)* addrspace(11)*, !dbg !8845
  %arrayptr.i22 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %38, align 16, !dbg !8845, !tbaa !247, !alias.scope !8846, !noalias !209, !nonnull !158
  %arrayref.i = load double, double addrspace(13)* %arrayptr.i22, align 8, !dbg !8845, !tbaa !1172, !alias.scope !255, !noalias !256
  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %3) #183, !dbg !8847
  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %4) #183, !dbg !8847
  call void @llvm.lifetime.end.p0i8(i64 32, i8* nonnull %5) #183, !dbg !8847
  ret double %arrayref.i, !dbg !8848
}

No augmented forward pass found for ijl_tagged_gensym
 at context:   %31 = call nonnull {} addrspace(10)* @ijl_tagged_gensym(i64 %string_ptr.i, i64 noundef 14) #183, !dbg !290

Stacktrace:
  [1] gensym
    @ ./expr.jl:16
  [2] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:10
  [3] #init_call#30
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:529
  [4] init_call
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:503
  [5] #init_up#33
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:562
  [6] init_up
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:549
  [7] #init#31
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:542
  [8] init
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:533
  [9] senseloss
    @ ./REPL[4]:3
 [10] senseloss
    @ ./REPL[4]:0


Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:1317
  [2] gensym
    @ ./expr.jl:16 [inlined]
  [3] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:10 [inlined]
  [4] #init_call#30
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:529 [inlined]
  [5] init_call
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:503 [inlined]
  [6] #init_up#33
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:562 [inlined]
  [7] init_up
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:549 [inlined]
  [8] #init#31
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:542 [inlined]
  [9] init
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:533 [inlined]
 [10] senseloss
    @ ./REPL[4]:3 [inlined]
 [11] senseloss
    @ ./REPL[4]:0 [inlined]
 [12] diffejulia_senseloss_3007_inner_1wrap
    @ ./REPL[4]:0
 [13] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:5299 [inlined]
 [14] enzyme_call
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:4977 [inlined]
 [15] CombinedAdjointThunk
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/Dd2LU/src/compiler.jl:4919 [inlined]
 [16] autodiff(::ReverseMode{false, FFIABI}, f::Const{senseloss{…}}, ::Type{Active}, args::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/Dd2LU/src/Enzyme.jl:215
 [17] autodiff(::ReverseMode{false, FFIABI}, ::senseloss{InterpolatingAdjoint{…}}, ::Type, ::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/Dd2LU/src/Enzyme.jl:224
 [18] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

I don't have much experience reading output like this, but it seems to point to:
https://github.com/SciML/OrdinaryDiffEq.jl/blob/master/src/solve.jl#L66 ?

@wsmoses
Copy link

wsmoses commented Jan 27, 2024 via email

@ArnoStrouwen
Copy link
Member

Indeed, with that branch of Enzyme, I now get the different error:

julia> Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) # need to run twice to get below output
ERROR: Enzyme execution failed.
Mismatched activity for:   store atomic {} addrspace(10)* %1, {} addrspace(10)* addrspace(11)* %108 release, align 8, !dbg !380, !tbaa !213, !alias.scope !217, !noalias !242 const val: {} addrspace(10)* %1
Type tree: {[-1]:Pointer}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] DEOptions
   @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/type.jl:4
 [2] #__init#747
   @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:366

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/cL7zZ/src/compiler.jl:1319
  [2] DEOptions
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/integrators/type.jl:4 [inlined]
  [3] #__init#747
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:366
  [4] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEq/2nLli/src/solve.jl:10 [inlined]
  [5] #init_call#30
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:529 [inlined]
  [6] init_call
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:503 [inlined]
  [7] #init_up#33
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:562 [inlined]
  [8] init_up
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:549 [inlined]
  [9] #init#31
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:542 [inlined]
 [10] init
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:533 [inlined]
 [11] senseloss
    @ ~/SciML/SciMLSensitivity.jl/src/tester.jl:7 [inlined]
 [12] senseloss
    @ ~/SciML/SciMLSensitivity.jl/src/tester.jl:0 [inlined]
 [13] diffejulia_senseloss_3992_inner_1wrap
    @ ~/SciML/SciMLSensitivity.jl/src/tester.jl:0
 [14] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/cL7zZ/src/compiler.jl:5308 [inlined]
 [15] enzyme_call
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/cL7zZ/src/compiler.jl:4986 [inlined]
 [16] CombinedAdjointThunk
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/cL7zZ/src/compiler.jl:4928 [inlined]
 [17] autodiff(::ReverseMode{false, FFIABI}, f::Const{senseloss{…}}, ::Type{Active}, args::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/cL7zZ/src/Enzyme.jl:215
 [18] autodiff(::ReverseMode{false, FFIABI}, ::senseloss{InterpolatingAdjoint{…}}, ::Type, ::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/cL7zZ/src/Enzyme.jl:224
 [19] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.

Curiously, with the below example, I got the correct answer (1.0, 0.0) once. But it is not reproducible.

using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme
function g(u0p)
	prob = ODEProblem{false}((u, p, t) -> u .* p, u0p[1:1], (0.0, 1.0), u0p[2:2])
    integrator = init(prob, Tsit5())
    integrator.u[1]
end
u0p = [2.0, 3.0]
du0p = zeros(2)
Enzyme.gradient(Reverse, g, u0p)

@ChrisRackauckas
Copy link
Member Author

It shouldn't be differentiating into the solve. That means it's not hitting the rule for some reason.

@ArnoStrouwen
Copy link
Member

Enzyme master works for me now!

using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme
struct senseloss{T}
    sense::T
end
function (f::senseloss)(u0p)
	prob = ODEProblem{false}((u, p, t) -> u .* p, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1))
end
u0p = [2.0, 3.0]
du0p = zeros(2)
dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) 
du0p  dup # gives true

It still prints quite a few of warnings about TypeAnalysisDepthLimit, however.
It seems like adding the gensym rule was sufficient.

@ArnoStrouwen
Copy link
Member

When manually specifying the sensealg, instead of relying on the automatic algorithm (which chooses forward mode here), there is still an error, but this seems like an issue on the SciML side:

using OrdinaryDiffEq, SciMLSensitivity, Zygote, Enzyme
Enzyme.API.runtimeActivity!(true)
struct senseloss{T}
    sense::T
end
function (f::senseloss)(u0p)
	prob = ODEProblem{true}((du, u, p, t) -> begin du .= u .* p  end, u0p[1:1], (0.0, 1.0), u0p[2:2])
    sum(solve(prob, Tsit5(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1, sensealg = f.sense)) # not specifying sensealg works
end
u0p = [2.0, 3.0]
du0p = zeros(2)
senseloss(InterpolatingAdjoint())(u0p)
dup = Zygote.gradient(senseloss(InterpolatingAdjoint()), u0p)[1]
Enzyme.autodiff(Reverse, senseloss(InterpolatingAdjoint()), Active, Duplicated(u0p, du0p)) 
du0p  dup #fails
ERROR: TypeError: in typeassert, expected ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#13#14", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#13#14", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing}, got a value of type ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, ODEFunction{false, SciMLBase.AutoSpecialize, var"#13#14", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, SciMLBase.AutoSpecialize, var"#13#14", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEq.Tsit5ConstantCache, Nothing}, SciMLBase.DEStats, Nothing}
Stacktrace:
  [1] #augmented_primal#1
    @ ~/.julia/packages/DiffEqBase/eLhx9/ext/DiffEqBaseEnzymeExt.jl:25
  [2] augmented_primal
    @ ~/.julia/packages/DiffEqBase/eLhx9/ext/DiffEqBaseEnzymeExt.jl:9 [inlined]
  [3] #solve#40
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:981 [inlined]
  [4] solve
    @ ~/.julia/packages/DiffEqBase/eLhx9/src/solve.jl:971 [inlined]
  [5] senseloss
    @ ./REPL[28]:3 [inlined]
  [6] senseloss
    @ ./REPL[28]:0 [inlined]
  [7] diffejulia_senseloss_18987_inner_1wrap
    @ ./REPL[28]:0
  [8] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RN4Ys/src/compiler.jl:5308 [inlined]
  [9] enzyme_call
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RN4Ys/src/compiler.jl:4986 [inlined]
 [10] CombinedAdjointThunk
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/RN4Ys/src/compiler.jl:4928 [inlined]
 [11] autodiff(::ReverseMode{false, FFIABI}, f::Const{senseloss{…}}, ::Type{Active}, args::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/RN4Ys/src/Enzyme.jl:215
 [12] autodiff(::ReverseMode{false, FFIABI}, ::senseloss{InterpolatingAdjoint{…}}, ::Type, ::Duplicated{Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/RN4Ys/src/Enzyme.jl:224
 [13] top-level scope
    @ REPL[32]:1
Some type information was truncated. Use `show(err)` to see complete types.

@ChrisRackauckas
Copy link
Member Author

Can you diff the two types? What's actually different?

@ArnoStrouwen
Copy link
Member

Capture

@ChrisRackauckas
Copy link
Member Author

Interesting. I think after the t is subbed in for the saveat output it's turned into a range while the solvers always make it a vector. Force the t in the sensitivity_solution to be a vector

ChrisRackauckas and others added 2 commits February 5, 2024 03:55
Co-authored-by: Arno Strouwen <arno.strouwen@telenet.be>
@ArnoStrouwen
Copy link
Member

Enzyme + TrackerAdjoint has the issue that all the real numbers in the solution are tracked, while enzyme expects them not to be.
Screenshot_20240214_120355

@ArnoStrouwen
Copy link
Member

This one reads to me like it should be handled in Enzyme @wsmoses

  LoadError: Enzyme execution failed.
  Enzyme: Not yet implemented augmented forward for jl_f__apply_iterate (true, true, iterate, Core.apply_type, 7, 6)
  Stacktrace:
   [1] signature_type
     @ ./reflection.jl:962
   [2] _methods
     @ ./reflection.jl:1020
  
  Stacktrace:
    [1] throwerr(cstr::Cstring)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:1287
    [2] signature_type
      @ ./reflection.jl:962 [inlined]
    [3] _methods
      @ ./reflection.jl:1020 [inlined]
    [4] augmented_julia__methods_39623wrap
      @ ./reflection.jl:0
    [5] macro expansion
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5376 [inlined]
    [6] enzyme_call
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5054 [inlined]
    [7] AugmentedForwardThunk
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5007 [inlined]
    [8] runtime_generic_augfwd(activity::Type{Val{(false, false, false, false, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(Base._methods), df::Nothing, primal_1::typeof(Main.var"##Alternative AD Frontend#225".f_aug), shadow_1_1::Nothing, primal_2::Type{Tuple}, shadow_2_1::Nothing, primal_3::Int64, shadow_3_1::Nothing, primal_4::UInt64, shadow_4_1::Nothing)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/rules/jitrules.jl:179
    [9] methods
      @ ./reflection.jl:1072 [inlined]
   [10] augmented_julia_methods_39596wrap
      @ ./reflection.jl:0
   [11] macro expansion
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5376 [inlined]
   [12] enzyme_call(::Val{false}, ::Ptr{Nothing}, ::Type{Enzyme.Compiler.AugmentedForwardThunk}, ::Type{Val{1}}, ::Val{true}, ::Type{Tuple{EnzymeCore.Const{typeof(Main.var"##Alternative AD Frontend#225".f_aug)}, EnzymeCore.Const{Type{Tuple}}, EnzymeCore.Const{Nothing}}}, ::Type{EnzymeCore.Duplicated{Base.MethodList}}, ::EnzymeCore.Const{typeof(methods)}, ::Type{@NamedTuple{1, 2, 3, 4, 5, 6::UInt64, 7, 8, 9, 10::Bool, 11, 12::Core.LLVMPtr{UInt64, 0}, 13::Core.LLVMPtr{Bool, 0}, 14::Core.LLVMPtr{Bool, 0}, 15}}, ::EnzymeCore.Const{typeof(Main.var"##Alternative AD Frontend#225".f_aug)}, ::EnzymeCore.Const{Type{Tuple}}, ::EnzymeCore.Const{Nothing})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5054
   [13] (::Enzyme.Compiler.AugmentedForwardThunk{Ptr{Nothing}, EnzymeCore.Const{typeof(methods)}, EnzymeCore.Duplicated{Base.MethodList}, Tuple{EnzymeCore.Const{typeof(Main.var"##Alternative AD Frontend#225".f_aug)}, EnzymeCore.Const{Type{Tuple}}, EnzymeCore.Const{Nothing}}, Val{1}, Val{true}(), @NamedTuple{1, 2, 3, 4, 5, 6::UInt64, 7, 8, 9, 10::Bool, 11, 12::Core.LLVMPtr{UInt64, 0}, 13::Core.LLVMPtr{Bool, 0}, 14::Core.LLVMPtr{Bool, 0}, 15}})(::EnzymeCore.Const{typeof(methods)}, ::EnzymeCore.Const{typeof(Main.var"##Alternative AD Frontend#225".f_aug)}, ::Vararg{Any})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5007
   [14] runtime_generic_augfwd(activity::Type{Val{(false, false, false, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(methods), df::Nothing, primal_1::typeof(Main.var"##Alternative AD Frontend#225".f_aug), shadow_1_1::Nothing, primal_2::Type{Tuple}, shadow_2_1::Nothing, primal_3::Nothing, shadow_3_1::Nothing)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/rules/jitrules.jl:179
   [15] methods (repeats 2 times)
      @ ./reflection.jl:1094 [inlined]
   [16] numargs
      @ ~/.julia/packages/SciMLBase/QSc1r/src/utils.jl:12 [inlined]
   [17] #isinplace#3
      @ ~/.julia/packages/SciMLBase/QSc1r/src/utils.jl:244 [inlined]
   [18] augmented_julia__isinplace_3_38717_inner_1wrap
      @ ~/.julia/packages/SciMLBase/QSc1r/src/utils.jl:0
   [19] macro expansion
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5376 [inlined]
   [20] enzyme_call
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5054 [inlined]
   [21] AugmentedForwardThunk
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5007 [inlined]
   [22] runtime_generic_augfwd(activity::Type{Val{(false, false, false, false, false, false, false, false)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, true, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::SciMLBase.var"##isinplace#3", df::Nothing, primal_1::Bool, shadow_1_1::Nothing, primal_2::Bool, shadow_2_1::Nothing, primal_3::typeof(SciMLBase.isinplace), shadow_3_1::Nothing, primal_4::typeof(Main.var"##Alternative AD Frontend#225".f_aug), shadow_4_1::Nothing, primal_5::Int64, shadow_5_1::Nothing, primal_6::String, shadow_6_1::Nothing, primal_7::Bool, shadow_7_1::Nothing)
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/rules/jitrules.jl:179
   [23] isinplace (repeats 2 times)
      @ ~/.julia/packages/SciMLBase/QSc1r/src/utils.jl:242 [inlined]
   [24] #ODEProblem#307
      @ ~/.julia/packages/SciMLBase/QSc1r/src/problems/ode_problems.jl:188 [inlined]
   [25] augmented_julia__ODEProblem_307_38622_inner_1wrap
      @ ~/.julia/packages/SciMLBase/QSc1r/src/problems/ode_problems.jl:0
   [26] macro expansion
      @ ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5376 [inlined]
   [27] enzyme_call
      @ ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5054 [inlined]
   [28] AugmentedForwardThunk
      @ ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5007 [inlined]
   [29] runtime_generic_augfwd(activity::Type{Val{(false, false, false, false, false, true, true)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::SciMLBase.var"##ODEProblem#307", df::Nothing, primal_1::@Kwargs{alg::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, Nothing}}, shadow_1_1::Nothing, primal_2::Type{SciMLBase.ODEProblem}, shadow_2_1::Nothing, primal_3::typeof(Main.var"##Alternative AD Frontend#225".f_aug), shadow_3_1::Nothing, primal_4::Matrix{Float64}, shadow_4_1::Nothing, primal_5::Tuple{Float64, Float64}, shadow_5_1::Base.RefValue{Tuple{Float64, Float64}}, primal_6::Vector{Float64}, shadow_6_1::Vector{Float64})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/rules/jitrules.jl:179
   [30] ODEProblem
      @ ~/.julia/packages/SciMLBase/QSc1r/src/problems/ode_problems.jl:187 [inlined]
   [31] ODEProblem
      @ ~/.julia/packages/SciMLBase/QSc1r/src/problems/ode_problems.jl:0 [inlined]
   [32] augmented_julia_ODEProblem_38465_inner_1wrap
      @ ~/.julia/packages/SciMLBase/QSc1r/src/problems/ode_problems.jl:0
   [33] macro expansion
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5376 [inlined]
   [34] enzyme_call
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5054 [inlined]
   [35] AugmentedForwardThunk
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5007 [inlined]
   [36] runtime_generic_augfwd(activity::Type{Val{(false, true, false, false, false, false, true)}}, width::Val{1}, ModifiedBetween::Val{(true, true, true, true, true, true, true)}, RT::Val{@NamedTuple{1, 2, 3}}, f::typeof(Core.kwcall), df::Nothing, primal_1::@NamedTuple{alg::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, Nothing}}, shadow_1_1::@NamedTuple{alg::OrdinaryDiffEq.Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, Nothing}}, primal_2::Type{SciMLBase.ODEProblem}, shadow_2_1::Nothing, primal_3::typeof(Main.var"##Alternative AD Frontend#225".f_aug), shadow_3_1::Nothing, primal_4::Matrix{Float64}, shadow_4_1::Nothing, primal_5::Tuple{Float64, Float64}, shadow_5_1::Nothing, primal_6::Vector{Float64}, shadow_6_1::Vector{Float64})
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/rules/jitrules.jl:179
   [37] loss
      @ ~/SciML/SciMLSensitivity.jl/test/alternative_ad_frontend.jl:171 [inlined]
   [38] augmented_julia_loss_38280wrap
      @ ~/SciML/SciMLSensitivity.jl/test/alternative_ad_frontend.jl:0
   [39] macro expansion
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5376 [inlined]
   [40] enzyme_call
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5054 [inlined]
   [41] AugmentedForwardThunk
      @ Enzyme.Compiler ~/.julia/packages/Enzyme/M8KTx/src/compiler.jl:5007 [inlined]
   [42] autodiff(::EnzymeCore.ReverseMode{false, EnzymeCore.FFIABI}, f::EnzymeCore.Const{typeof(Main.var"##Alternative AD Frontend#225".loss)}, ::Type{EnzymeCore.Active}, args::EnzymeCore.Duplicated{Vector{Float64}})
      @ Enzyme ~/.julia/packages/Enzyme/M8KTx/src/Enzyme.jl:198
   [43] autodiff
      @ ~/.julia/packages/Enzyme/M8KTx/src/Enzyme.jl:224 [inlined]
   [44] gradient(::EnzymeCore.ReverseMode{false, EnzymeCore.FFIABI}, f::Function, x::Vector{Float64})
      @ Enzyme ~/.julia/packages/Enzyme/M8KTx/src/Enzyme.jl:805
   [45] top-level scope
      @ ~/SciML/SciMLSensitivity.jl/test/alternative_ad_frontend.jl:191
   [46] include(mod::Module, _path::String)
      @ Base ./Base.jl:495
   [47] include(x::String)
      @ Main.var"##Alternative AD Frontend#225" ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:28
   [48] macro expansion
      @ ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:24 [inlined]
   [49] macro expansion
      @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [50] top-level scope
      @ ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:24
   [51] eval(m::Module, e::Any)
      @ Core ./boot.jl:385
   [52] macro expansion
      @ ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:28 [inlined]
   [53] macro expansion
      @ ./timing.jl:279 [inlined]
   [54] macro expansion
      @ ~/SciML/SciMLSensitivity.jl/test/runtests.jl:15 [inlined]
   [55] macro expansion
      @ ~/.julia/juliaup/julia-1.10.0+0.x64.linux.gnu/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
  --- the last 2 lines are repeated 1 more time ---
   [58] macro expansion
      @ ~/SciML/SciMLSensitivity.jl/test/runtests.jl:13 [inlined]
   [59] macro expansion
      @ ./timing.jl:279 [inlined]
   [60] top-level scope
      @ ~/SciML/SciMLSensitivity.jl/test/runtests.jl:269
   [61] include(fname::String)
      @ Base.MainInclude ./client.jl:489
   [62] top-level scope
      @ none:6
   [63] eval
      @ Core ./boot.jl:385 [inlined]
   [64] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:291
  in expression starting at /home/arno/SciML/SciMLSensitivity.jl/test/alternative_ad_frontend.jl:191

Co-authored-by: Arno Strouwen <arno.strouwen@telenet.be>
@ChrisRackauckas
Copy link
Member Author

Enzyme + TrackerAdjoint has the issue that all the real numbers in the solution are tracked, while enzyme expects them not to be.

The return is supposed to strip tracking information https://github.com/SciML/SciMLSensitivity.jl/blob/master/src/concrete_solve.jl#L1198-L1199. We probably just need to strip off more, though it's very strange that is showing tracker on the u part.

@ArnoStrouwen
Copy link
Member

The final set of CI failures are of the form:

julia> ReverseDiff.gradient(senseloss(ForwardSensitivity()),
           u0p)dup
ERROR: MethodError: no method matching increment_deriv!(::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ::ChainRulesCore.NotImplemented, ::Int64)

Closest candidates are:
  increment_deriv!(::AbstractArray, ::Any)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/derivatives/propagation.jl:38
  increment_deriv!(::ReverseDiff.TrackedArray, ::Real, ::Any)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/derivatives/propagation.jl:34
  increment_deriv!(::AbstractArray, ::Real, ::Any)
   @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/derivatives/propagation.jl:36
  ...

Stacktrace:
  [1] increment_deriv!
    @ ~/.julia/packages/ReverseDiff/UJhiD/src/derivatives/propagation.jl:40 [inlined]
  [2] _add_to_deriv!
    @ ~/.julia/packages/ReverseDiff/UJhiD/src/macros.jl:381 [inlined]
  [3] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [4] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [5] (::Base.Broadcast.var"#31#32")(k::Any)
    @ Base.Broadcast ./broadcast.jl:1118 [inlined]
  [6] macro expansion
    @ ./ntuple.jl:72 [inlined]
  [7] ntuple
    @ ./ntuple.jl:69 [inlined]
  [8] copy
    @ ./broadcast.jl:1118 [inlined]
  [9] materialize
    @ ./broadcast.jl:903 [inlined]
 [10] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{…}, ReverseDiff.TrackedArray{…}, Tuple{…}})
    @ DiffEqBaseReverseDiffExt ~/.julia/packages/ReverseDiff/UJhiD/src/macros.jl:218
 [11] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{…}, ReverseDiff.TrackedArray{…}, Tuple{…}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/tape.jl:93
 [12] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/tape.jl:87
 [13] reverse_pass!
    @ ~/.julia/packages/ReverseDiff/UJhiD/src/api/tape.jl:36 [inlined]
 [14] seeded_reverse_pass!(result::Vector{…}, output::ReverseDiff.TrackedReal{…}, input::ReverseDiff.TrackedArray{…}, tape::ReverseDiff.GradientTape{…})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/utils.jl:31
 [15] seeded_reverse_pass!(result::Vector{Float64}, t::ReverseDiff.GradientTape{senseloss{ForwardSensitivity{…}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{…}, Vector{…}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/tape.jl:47
 [16] gradient(f::senseloss{ForwardSensitivity{0, true, Val{:central}}}, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:24
 [17] gradient(f::senseloss{ForwardSensitivity{0, true, Val{:central}}}, input::Vector{Float64})
    @ ReverseDiff ~/.julia/packages/ReverseDiff/UJhiD/src/api/gradients.jl:22
 [18] top-level scope
    @ REPL[25]:1

while instead a ForwardSensitivityOutOfPlaceError is expected.

@ChrisRackauckas
Copy link
Member Author

That means ReverseDiff isn't getting captured into the adjoint system?

@ArnoStrouwen
Copy link
Member

For the tracker error, in this specific example it should be:

julia> [Tracker.data.(u) for u in sol.u]
11-element Vector{Vector{Float64}}:
 [2.0]
 [2.6997176151520326]
 [3.644237600781067]
 [4.919206222313961]
 [6.6402338454731975]
 [8.96337814067623]
 [12.099294928825994]
 [16.33233982513538]
 [22.046352761283476]
 [29.759463449746153]
 [40.17107384637566]

But that might break the non-Enzyme AD systems.

For the reversediff error, I have no clue. How can this PR break that, it does not touch Enzyme at all?

@ChrisRackauckas
Copy link
Member Author

The return from the adjoint should be untracked, so that should be fine.

For the reversediff error, I have no clue. How can this PR break that, it does not touch Enzyme at all?

That one is really odd. I'd setup a master test and see where we are at.

Co-authored-by: Arno Strouwen <arno.strouwen@telenet.be>
@test_broken Enzyme.gradient(Reverse, senseloss(ReverseDiffAdjoint()), u0p) ≈ dup
@test Enzyme.gradient(Reverse, senseloss(TrackerAdjoint()), u0p) ≈ dup
@test Enzyme.gradient(Reverse, senseloss(ForwardDiffSensitivity()), u0p) ≈ dup
@test_throws SciMLSensitivity.ForwardSensitivityOutOfPlaceError ReverseDiff.gradient(Reverse, senseloss(ForwardSensitivity()),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why an oop error? The problem is in position. It gives a different error. Is Reversediff + iip + ForwardSensitivity known to be broken, or should this work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this had more to do with ReverseDiff requiring AbstractArray returns, so reversediff is known to fail in this case (see it marked as failing in the rest of the file too)

@ChrisRackauckas
Copy link
Member Author

Merged as #1021

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants