diff --git a/Project.toml b/Project.toml
index 70456c04c..82dbf331a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -60,7 +60,7 @@ GPUArraysCore = "0.1"
 LinearSolve = "1"
 OrdinaryDiffEq = "6.19.1"
 Parameters = "0.12"
-PreallocationTools = "0.4"
+PreallocationTools = "0.4.4"
 QuadGK = "2.1"
 RandomNumbers = "1.5.3"
 RecursiveArrayTools = "2.4.2"
diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl
index bed260381..67d16fad7 100644
--- a/src/SciMLSensitivity.jl
+++ b/src/SciMLSensitivity.jl
@@ -15,7 +15,7 @@ import Enzyme
 import GPUArraysCore
 using StaticArrays
 
-import PreallocationTools: dualcache, get_tmp, DiffCache
+import PreallocationTools: dualcache, get_tmp, DiffCache, LazyBufferCache
 import FunctionWrappersWrappers
 
 using Cassette, DiffRules
diff --git a/src/reversediff.jl b/src/reversediff.jl
index ee0afe9e1..b9dacf8e6 100644
--- a/src/reversediff.jl
+++ b/src/reversediff.jl
@@ -110,3 +110,14 @@ ReverseDiff.@grad function solve_up(prob, sensealg, u0, p, args...; kwargs...)
     end
     Array(out[1]), actual_adjoint
 end
+
+# PreallocationTools https://github.com/SciML/PreallocationTools.jl/issues/39
+function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray)
+    s = b.sizemap(size(u)) # required buffer size
+    T = ReverseDiff.TrackedArray
+    buf = get!(b.bufs, (T, s)) do
+        # allocate a tracked-compatible buffer; key by (T, s) since b.bufs is untyped
+        similar(u, s)
+    end
+    return buf
+end
diff --git a/test/lazybuffer.jl b/test/lazybuffer.jl
new file mode 100644
index 000000000..c43d8de5f
--- /dev/null
+++ b/test/lazybuffer.jl
@@ -0,0 +1,45 @@
+using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools
+using Random, FiniteDiff, ForwardDiff, ReverseDiff, SciMLSensitivity, Zygote
+
+# see https://github.com/SciML/PreallocationTools.jl/issues/29
+# struct and method definitions must live at top level, not inside the @testset
+struct foo{T}
+    lbc::T
+end
+
+function (f::foo)(du, u, p, t)
+    tmp = f.lbc[u]
+    mul!(tmp, p, u) # in-place product; avoids allocating tmp = p * u
+    @. du = u + tmp
+    nothing
+end
+
+@testset "VJP computation with LazyBuffer" begin
+    u0 = rand(2, 2)
+    p = rand(2, 2)
+
+    f = foo(LazyBufferCache())
+    prob = ODEProblem(f, u0, (0.0, 1.0), p)
+
+    function loss(u0, p; sensealg = nothing)
+        _prob = remake(prob, u0 = u0, p = p)
+        _sol = solve(_prob, Tsit5(), sensealg = sensealg, saveat = 0.1, abstol = 1e-14,
+                     reltol = 1e-14)
+        sum(abs2, _sol)
+    end
+
+    loss(u0, p)
+
+    du0 = FiniteDiff.finite_difference_gradient(u0 -> loss(u0, p), u0)
+    dp = FiniteDiff.finite_difference_gradient(p -> loss(u0, p), p)
+    Fdu0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0)
+    Fdp = ForwardDiff.gradient(p -> loss(u0, p), p)
+    @test du0≈Fdu0 rtol=1e-8
+    @test dp≈Fdp rtol=1e-8
+
+    Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p;
+                                    sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())),
+                                u0, p)
+    @test du0≈Zdu0 rtol=1e-8
+    @test dp≈Zdp rtol=1e-8
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ee77951da..57c6b2635 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -46,6 +46,7 @@
     @time @safetestset "Continuous adjoint params" begin include("adjoint_param.jl") end
     @time @safetestset "Continuous and discrete costs" begin include("mixed_costs.jl") end
     @time @safetestset "Fully Out of Place adjoint sensitivity" begin include("adjoint_oop.jl") end
+    @time @safetestset "Differentiate LazyBuffer with ReverseDiff" begin include("lazybuffer.jl") end
 end
 
 if GROUP == "All" || GROUP == "Core4"
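
Note (not part of the patch): a minimal sketch of what the new `Base.getindex(::LazyBufferCache, ::ReverseDiff.TrackedArray)` method enables. PreallocationTools' generic method hands back a buffer of the same type as the input array, which fails for tracked inputs (the linked PreallocationTools issue #39); the method added in `src/reversediff.jl` instead caches a `similar`-allocated buffer under the key `(ReverseDiff.TrackedArray, size)`. The snippet below is illustrative only and assumes the patched packages are loaded and that `similar` on a tracked array yields a tracked-compatible buffer:

```julia
using PreallocationTools, ReverseDiff

lbc = LazyBufferCache()

# Plain input: handled by PreallocationTools' generic method, which caches a
# Float64 buffer under the key (Matrix{Float64}, (2, 2)).
u = rand(2, 2)
@assert lbc[u] isa Matrix{Float64}

# Tracked input: handled by the new method, which caches a separate buffer
# under (ReverseDiff.TrackedArray, (2, 2)), so the two callers never clobber
# each other and gradients can flow through the buffer.
g = ReverseDiff.gradient(u) do x
    buf = lbc[x]        # x is a ReverseDiff.TrackedArray inside the gradient call
    copyto!(buf, x)     # in-place use of the buffer, as in the ODE right-hand side
    sum(abs2, buf)
end
@assert g ≈ 2 .* u      # gradient of sum(abs2, x) is 2x
```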