
Zygote Batch Support #39

Merged
merged 4 commits into from
Sep 4, 2020


agerlach
Collaborator

@agerlach agerlach commented Sep 3, 2020

Closes #37

@ChrisRackauckas
Member

I'd be wary of merging this without figuring out how this is actually different: it looks like it's computing the same thing, just without the parallelism, which seems wrong.

@agerlach
Collaborator Author

agerlach commented Sep 4, 2020

"Wrong" as in there is a better way, or "wrong" as in incorrect? It is computing the same thing. It is equivalent to

dfdp = function (x,p)
    _,back = Zygote.pullback(p->prob.f(x,p),p)
    back(y)[1]
end

if prob.batch > 0
    _dfdp = dfdp
    dfdp = function (x,p)
        out = zeros(length(p),size(x,2))
        for idx in 1:size(x,2)
            out[:,idx] .= _dfdp(@view(x[:,idx]),p)
        end
        out
    end
end

@ChrisRackauckas
Member

But the first is computing incorrect values, right? I think it's best to figure out why, instead of just trying to avoid the issue.

@agerlach
Collaborator Author

agerlach commented Sep 4, 2020

If you call

dfdp = function (x,p)
    _,back = Zygote.pullback(p->prob.f(x,p),p)
    back(y)[1]
end

directly for batch > 0, you will get incorrect answers. It seems like the issue is the size of y. Oddly, some of the tests report the wrong answer, while others error about dimensions. I think the difference between the two depends on whether the function is in-place or not.

@agerlach
Collaborator Author

agerlach commented Sep 4, 2020

Here is a MWE showing the issue

using Zygote
w(x,p) = x[1,:]*p[1] + x[2,:]*p[2]
batch = 5
_,back  = Zygote.pullback(p->w(repeat([1,2],inner=(1,batch)),p),[3 2])
back([1])[1] # returns [5, 10], but we need behaviour like [1 1 1 1 1; 2 2 2 2 2]
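
To see where a value like [5, 10] can come from, here is a minimal sketch that does not use Zygote at all. It assumes only the `w` from the MWE: since `w` is linear in `p`, its Jacobian with respect to `p` is `x'`, so the vector-Jacobian product for a seed `z` is simply `x * z`. The helper name `pullback_w` is hypothetical and written by hand for this particular `w`; it is not Zygote's implementation. Seeding with all ones sums the per-sample gradients over the batch, which matches the reported result, while seeding with one basis vector at a time recovers one column per sample.

```julia
# w from the MWE: one output per column (batch sample) of x
w(x, p) = x[1, :] * p[1] + x[2, :] * p[2]

# Hand-written vector-Jacobian product for this linear w (hypothetical helper):
# the Jacobian dw/dp is x', so pulling back a seed z gives x * z.
pullback_w(x, z) = x * z

batch = 5
x = repeat([1, 2], inner = (1, batch))   # 2×5 matrix, one column per sample

# A seed of all ones sums the per-sample gradients over the batch.
summed = pullback_w(x, ones(batch))

# Seeding with each basis vector in turn keeps the samples separate.
per_sample = zeros(2, batch)
for idx in 1:batch
    e = zeros(batch)
    e[idx] = 1
    per_sample[:, idx] = pullback_w(x, e)
end
```

Here `summed` is the batch-summed gradient and `per_sample` has the 2×batch layout asked for above.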

dfdp = function (x,p)
    out = zeros(length(p),size(x,2))
    for idx in 1:size(x,2)
        _,back = Zygote.pullback(p->prob.f(@view(x[:,idx]),p),p)
Collaborator Author


Still feels like a hack, but the pullback can be taken out of the loop, without the view, and we can then loop through back accordingly.

Using MWE:

using Zygote
w(x,p) = x[1,:]*p[1] + x[2,:]*p[2]
batch = 5
_,back  = Zygote.pullback(p->w(repeat([1,2],inner=(1,batch)),p),[3 2])

out = zeros(2, batch )
for idx in 1:batch
    z = zeros(batch)
    z[idx] = 1
    out[:,idx] = back(z)[1]
end

Member


You shouldn't need to pull back the pieces separately though: this is the pullback of the identity matrix for a sparse Jacobian, which is a matrix coloring problem that has a single-color solution.

Collaborator Author


I'm not sure I understand.

Member


using Zygote
w(x,p) = x[1,:]*p[1] + x[2,:]*p[2]
batch = 5
_,back  = Zygote.pullback(p->w(repeat([1,2],inner=(1,batch)),p),[3 2])

out = zeros(2, batch )
z = zeros(batch)
for idx in 1:batch
    z[idx] = 1
    out[:,idx] = back(z)[1]
    z[idx] = 0
end
out

is how I'd do it. It's not a hack but the right way to do it, because you have to pull back the basis elements if the function isn't sparse, and here it's not sparse in p. This keeps the primal calculation down to a single evaluation and is equivalent to the ForwardDiff chunk version.
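
The basis-vector loop can be wrapped into a reusable function. The following is only a sketch under stated assumptions: `f(x, p)` returns one scalar per column of `x` (a single output per sample), `p` is a plain vector, and the name `batched_dfdp` is hypothetical rather than anything taken from this PR. It relies only on `Zygote.pullback`, which returns the primal value together with the pullback closure.

```julia
using Zygote

# Hypothetical helper (not from the PR): gradient of each batch output with
# respect to p, assuming f(x, p) returns one scalar per column of x.
function batched_dfdp(f, x, p)
    y, back = Zygote.pullback(q -> f(x, q), p)   # primal is evaluated once
    out = zeros(length(p), length(y))
    z = zeros(length(y))                         # reusable seed vector
    for idx in eachindex(z)
        z[idx] = 1                               # select the idx-th basis vector
        out[:, idx] .= vec(back(z)[1])           # one pullback per batch sample
        z[idx] = 0                               # reset the seed for the next sample
    end
    return out
end

w(x, p) = x[1, :] * p[1] + x[2, :] * p[2]
x = repeat([1.0, 2.0], inner = (1, 5))
out = batched_dfdp(w, x, [3.0, 2.0])
```

Because this `w` is linear in `p`, `out` should equal `x` here, which makes the example easy to check by hand.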

@agerlach
Collaborator Author

agerlach commented Sep 4, 2020

This should reduce the primal calculation, as we discussed, for both the in-place and out-of-place cases. It currently crashes for nout > 1.

@ChrisRackauckas ChrisRackauckas merged commit 2fa4a7d into SciML:master Sep 4, 2020
Successfully merging this pull request may close these issues.

Zygote producing wrong result for ndim >1 & batch > 0