-
-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Zygote Batch Support #39
Conversation
I'd be weary of merging this without figuring out how this is actually different: it looks like it's computing the same thing, just without the parallelism, which seems wrong. |
"wrong" as in there is a better way, or "wrong" as in incorrect? It is computing the same thing. It is equivelent to
|
But the first is computing incorrect values right? I think it's best to figure out why, instead of trying to just avoid the issue. |
If you call
directly for batch >0 you will get incorrect answers. I seems like the issue is the size of |
Here is a MWE showing the issue
|
src/Quadrature.jl
Outdated
dfdp = function (x,p) | ||
out = zeros(length(p),size(x,2)) | ||
for idx in 1:size(x,2) | ||
_,back = Zygote.pullback(p->prob.f(@view(x[:,idx]),p),p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still feels like a hack, but the pullback
can get taken out of the loop w/o the view and then loop through back
accordingly.
Using MWE:
using Zygote
w(x,p) = x[1,:]*p[1] + x[2,:]*p[2]
batch = 5
_,back = Zygote.pullback(p->w(repeat([1,2],inner=(1,batch)),p),[3 2])
out = zeros(2, batch )
for idx in 1:batch
z = zeros(batch)
z[idx] = 1
out[:,idx] = back(z)[1]
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't need to pull back the pieces separately though: this is pullback of the identity matrix of a sparse Jacobian which is the matrix coloring problem that has a single color solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's https://mitmath.github.io/18337/lecture9/stiff_odes but with reverse mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using Zygote
w(x,p) = x[1,:]*p[1] + x[2,:]*p[2]
batch = 5
_,back = Zygote.pullback(p->w(repeat([1,2],inner=(1,batch)),p),[3 2])
out = zeros(2, batch )
z = zeros(batch)
for idx in 1:batch
z[idx] = 1
out[:,idx] = back(z)[1]
z[idx] = 0
end
out
is how I'd do it. It's not a hack but the right way to do it, because you have to pull back the basis elements if the function isn't sparse, and here it's not sparse in p
. This keeps the primal calculation down to a single time and is the equivalent to the forwarddiff chunk version.
This should reduce the primal calculation as we discussed for in and out of place. Currently crashes for |
Closes #37