ForwardDiff Batch Mode Support #29
src/Quadrature.jl
dfdp = function (dx,x,p)
# dfdp = function (dx,x,p)
dfdp = function (x,p)
Changed to out-of-place (oop) because I was unsure how to copy the result to dx without mutating. If we use Buffer, we need to allocate for the result anyway.
dx = Zygote.Buffer(x)
prob.f(dx,x,p)
copy(dx)
_dx = Zygote.Buffer(x, prob.nout, size(x,2))
Some of the quadrature methods do not adhere to prob.batch. It looks like it serves more as a maximum batch size, and some methods "grow" the batch size. So the solution size needs to be set accordingly.
Furthermore, Quadrature.jl calls prob.f([lb ub], p) to check whether the integrand returns a Vector. So if batch>0, the code will always try a two-point batch first.
It looks like batch mode is now working for ForwardDiff, but Zygote still has issues. For R->R it is working, but for R^n->R, the first element of the gradient is the wrong value. Interestingly it is equal to the true solution *
using Quadrature, Cuba, Cubature, Zygote, FiniteDiff, ForwardDiff, Test
### Batch Single dim
f(x,p) = x*p[1] .+ p[2]*p[3]
lb = 1.0
ub = 3.0
p = [2.0, 3.0, 4.0]
prob = QuadratureProblem(f,lb,ub,p)
function testf3(lb,ub,p; f=f)
    prob = QuadratureProblem(f,lb,ub,p, batch = 10, nout=1)
    solve(prob, CubatureJLh(); reltol=1e-3,abstol=1e-3)[1]
end
dp1 = ForwardDiff.gradient(p->testf3(lb,ub,p),p)
dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p->testf3(lb,ub,p),p)
@test dp1 ≈ dp3 # passes
@test dp2 ≈ dp3 # passes
### Batch multi dim
f(x,p) = x[1,:]*p[1] .+ p[2]*p[3]
lb = [1.0,1.0]
ub = [3.0,3.0]
p = [2.0, 3.0, 4.0]
prob = QuadratureProblem(f,lb,ub,p)
function testf3(lb,ub,p; f=f)
    prob = QuadratureProblem(f,lb,ub,p, batch = 10, nout=1)
    solve(prob, CubatureJLh(); reltol=1e-3,abstol=1e-3)[1]
end
dp1 = ForwardDiff.gradient(p->testf3(lb,ub,p),p)
dp2 = Zygote.gradient(p->testf3(lb,ub,p),p)[1]
dp3 = FiniteDiff.finite_difference_gradient(p->testf3(lb,ub,p),p)
@test dp1 ≈ dp3 # passes
@test dp2 ≈ dp3 # fails: [136.0,16.0,12.0] ≈ [8.0,16.0,12.0]
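For reference, the expected gradients in both tests can be checked by hand, since the integrands are linear in x. The following is a sketch of that calculation, assuming the bounds and p = [2.0, 3.0, 4.0] from the snippets above; I(p) is just a label introduced here for the value of the integral.

```latex
% Single dim: f(x,p) = x p_1 + p_2 p_3 on [1,3]
I_1(p) = \int_1^3 (x\,p_1 + p_2 p_3)\,dx = 4 p_1 + 2 p_2 p_3,
\qquad
\nabla_p I_1 = (4,\; 2 p_3,\; 2 p_2) = (4,\; 8,\; 6)

% Multi dim: f(x,p) = x_1 p_1 + p_2 p_3 on [1,3]^2
I_2(p) = \int_1^3\!\!\int_1^3 (x_1 p_1 + p_2 p_3)\,dx_1\,dx_2 = 8 p_1 + 4 p_2 p_3,
\qquad
\nabla_p I_2 = (8,\; 4 p_3,\; 4 p_2) = (8,\; 16,\; 12)
```

This agrees with the FiniteDiff value [8.0,16.0,12.0] in the failing comparison. Note that the first component is the only one whose integrand, ∂f/∂p_1 = x_1, depends on x, which may be relevant to why only that entry of the Zygote result is off.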
Change the title and let's merge this, at least for now. That is... an odd Zygote behavior for sure haha. Probably some accidental referencing instead of copying.
Hold off on merging. I am going to push an additional test first.
OK, it is ready. I added the single and multi dim tests from above and added
I previously missed the fix for in-place (iip) integrands. It should be corrected now, with a test. This now allows batch mode AD with DEU.
Corrects the output size for Zygote Batch mode Adjoint.