Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve adjoint for product and zip #1489

Merged
merged 16 commits into from
Jan 19, 2024
Merged

Improve adjoint for product and zip #1489

merged 16 commits into from
Jan 19, 2024

Conversation

lxvm
Copy link
Contributor

@lxvm lxvm commented Jan 2, 2024

Hi,

I've returned to my first contribution in #1170 since I noticed I couldn't differentiate w.r.t Iterator.products that have a number as an iterator. This pr adds a test and fixes the issue while also improving the inferrability of the adjoint.

PR Checklist

  • Tests are added
  • Documentation, if applicable

@lxvm
Copy link
Contributor Author

lxvm commented Jan 3, 2024

Update: since we have adjoints for product and collect, I added an adjoint for collect(product()) intended to work with this example

using Zygote, Test
@test Zygote.gradient(x -> sum(broadcast(prod, Iterators.product(x .^ 2,x))), ones(4)) == (3*4ones(4),)

src/lib/array.jl Outdated Show resolved Hide resolved
Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mostly LGTM. Are you able to add tests which check that one or more of the gradient computation, pullback or productfunc are type stable? With the effort put into this, we should make sure inference keeps working!

src/lib/array.jl Outdated Show resolved Hide resolved
@lxvm
Copy link
Contributor Author

lxvm commented Jan 16, 2024

Sure, I'll try and get a test written soon, although testing inference can be tricky. Ideally Test.@inferred would work well, but it also has to pick up a regression and I'll try to figure it out tomorrow.

Additionally, I can update the Iterators.zip adjoints to match the product iterator

@lxvm
Copy link
Contributor Author

lxvm commented Jan 16, 2024

I've finished adding inference and correctness tests for product and zip. For zip, some care has to be taken for Numbers, since they are iterable. It looks like the CI errors are unrelated. Does this look good?

Update: In general, projecting onto the input iterator type is unlikely to work for custom iterators, take for example a range as seen here:

julia> Zygote.gradient(x -> sum(prod, x), 1:5)
([1.0, 1.0, 1.0, 1.0, 1.0],)

julia> Zygote.gradient(x -> sum(prod, zip(x)), 1:5)
(nothing,)

Any thoughts on how to improve handling this?

Correction: projection is not the reason the example above gives nothing, but I'm not sure what is (can open a separate issue). Still, shouldn't the adjoints for the iterators be able to handle the projection themselves, as explained here?

@ToucheSir
Copy link
Member

Although it does use ChainRules' projection machinery sometimes, Zygote overall doesn't do projection quite the same way because it predates ChainRules. The legacy projection machinery we do have can be rather inconsistent. In this particular case however, it looks like the input type is causing null gradients?

julia> gradient(x -> sum(prod, zip(x)), collect(1:5))
([1.0, 1.0, 1.0, 1.0, 1.0],)

julia> gradient(x -> sum(prod, zip(x)), 1.0:5.0)
([1.0, 1.0, 1.0, 1.0, 1.0],)

It makes sense that an integer range would be considered non-differentiable, but it would be good to confirm Zygote is doing this for the right reason and not because of some bug. Either way, if you can't figure out zip easily I'd just leave it for a follow-up PR and we can try to get the product changes in first.

@lxvm
Copy link
Contributor Author

lxvm commented Jan 17, 2024

Thank you for the context about projection in Zygote. I'm happy with keeping it as is to have this pr be as non-breaking as possible.

Otherwise the work on zip is done and I added tests that are equivalent to those for product. I did switch the adjoint for the constructor Iterators.Zip to one for the function Iterators.zip and I'm not sure if something depended especially on the former since there were previously no tests for it.

As for the observation of null gradients for integer ranges, it appears to have nothing to do with zip and everything to do with iteration. Here are some more cases

julia> Zygote.gradient(x -> sum(prod, Iterators.product(x)), 1:5)
(nothing,)

julia> Zygote.gradient(x -> sum(prod, Iterators.map(identity, x)), 1:5)
(nothing,)

julia> Zygote.gradient(x -> sum(prod, Iterators.take(x,5)), 1:5)
(nothing,)

I'd have to understand where the decision is being made, but I think it's safe to leave it to a follow-up.

@ToucheSir
Copy link
Member

ToucheSir commented Jan 17, 2024

It turns out the answer is easier than I thought:

@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing

This may have been for supporting for x in 1:N ....

@mcabbott
Copy link
Member

Zygote often throws away gradients of a UnitRange, here:

@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing

It's not enforced by projection, so things that hit other rules such as gradient(x -> sum(abs2, x), 1:5) don't give nothing.

I did switch the adjoint for the constructor Iterators.Zip to one for the function Iterators.zip and I'm not sure if something depended especially on the former since there were previously no tests for it.

I have no memory of why, but when initially writing these rules, attaching them to the uppercase constructor not the lowercase function somehow made more cases work. There are tests here but fewer than I thought.

@lxvm lxvm changed the title Improve adjoint for product Improve adjoint for product and zip Jan 17, 2024
test/lib/array.jl Outdated Show resolved Hide resolved
@lxvm
Copy link
Contributor Author

lxvm commented Jan 17, 2024

I've rebased this branch onto master and resolved the last issue I was concerned about. Looks like the CI is mostly good, although I'm not sure if the DynamicPPL failure is related

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, a couple of things to touch up before merging.

src/lib/array.jl Outdated Show resolved Hide resolved
src/lib/array.jl Outdated
@@ -169,8 +170,11 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
_tryaxes(::Number) = Val(-1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a Val, maybe nothing or missing would be a more appropriate sentinel value here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, or even the number itself

src/lib/array.jl Outdated Show resolved Hide resolved
src/lib/array.jl Outdated Show resolved Hide resolved
lxvm and others added 2 commits January 19, 2024 13:20
Remove `Val` from `ntuple`s where constant propagation occurs

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
@ToucheSir
Copy link
Member

This is a great contribution for a tricky set of rules, thanks @lxvm !

@ToucheSir ToucheSir merged commit 46477ee into FluxML:master Jan 19, 2024
11 of 13 checks passed
@lxvm lxvm deleted the product branch January 19, 2024 22:17
@lxvm
Copy link
Contributor Author

lxvm commented Jan 19, 2024

Thanks to everyone for the helpful support!

@lxvm lxvm restored the product branch January 19, 2024 23:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants