add map/broadcast/algebra/iteration/dict interface for Grads #902

Merged
merged 17 commits into from
Feb 20, 2021

CarloLucibello
Member

Member

@darsnack darsnack left a comment

Looks good, thanks for taking care of this! My comment is that since Grads behaves like a Dict, I think we should limit this to operations involving only Grads (where we can verify the key set equality completely).

@CarloLucibello
Member Author

I'm trying to figure out how to handle nothing

@CarloLucibello
Member Author

I'm trying to figure out how to handle nothing

I decided to materialize nothing into an array of zeros when it is involved in map
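
A minimal sketch of that idea, using a plain IdDict as a stand-in for Grads rather than Zygote's actual implementation. The helper name `materialize_grad` is hypothetical and only illustrates replacing a missing gradient with zeros shaped like its parameter before a value-wise function is applied:

```julia
# Hypothetical helper (not Zygote's code): a missing gradient (`nothing`)
# becomes an array of zeros with the same shape as its parameter.
materialize_grad(p, g) = g === nothing ? zero(p) : g

w = [1.0, 2.0]
b = [0.5]

# Stand-in for a Grads object: parameters are keys, gradients are values.
grads = IdDict{Any,Any}(w => [0.1, -0.3], b => nothing)

# Value-wise clamp; the function only ever sees arrays, never `nothing`.
clipped = IdDict{Any,Any}()
for (p, g) in grads
    clipped[p] = clamp.(materialize_grad(p, g), -0.2, 0.2)
end
```

With this convention the mapped function does not need its own `nothing` branch, at the cost of allocating a zeros array for each missing entry.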

@CarloLucibello CarloLucibello changed the title add map/broadcast/algebra for Grads add map/broadcast/algebra/iteration/dict interface for Grads Feb 20, 2021
@CarloLucibello
Member Author

ok, I think I'm happy with the design and implementation, should be ready

@oxinabox
Member

There is an argument to be made that map should get the key and the value,
and that broadcast should just get the value.

Broadcast just seeing value would make sense to me.
e.g. g .= 2 .* (a .+ b)
Because one reasonably will want to broadcast against operations on numbers (which have no key).

But also access to keys is useful.
When writing an earlier comment I wanted to suggest the following as a use case and realized I couldn't:
Consider learning rate scaling, which used to be a big deal for deep belief networks,
where Gaussian layers should have 0.01 times the learning rate used for Bernoulli layers.

scaled_g = map(g) do k, v
    if is_gaussian_layer(k)
        0.01*v
    else
        v
    end
end

This wouldn't quite match what @darsnack proposed here FluxML/Flux.jl#707 (comment),
but I think this distinction might make sense.

Although: NamedTuples (and ChainRules.Composites) do just map with values.
So this wouldn't be consistent with that.
But OTOH, if this does work out (map with pairs, and broadcast with values), then this could be a good place to test it, and it could resolve the stalemate that the Base.AbstractDict arguments are having.
(and that resolution would also result in map(pairs(Dict)) do k,v working, which feels right to me.)

@CarloLucibello
Member Author

CarloLucibello commented Feb 20, 2021

Consider learning rate scaling, which used to be a big deal for deep belief networks,
where Gaussian layers should have 0.01 times the learning rate used for Bernoulli layers.

Consider that keys are parameters, not layers, so I don't think you can detect if is_gaussian_layer in your example.

The use cases I can think of, such as clamping, make use only of values, so value-based map and iteration seem more convenient, and in any case pairs are just a few keystrokes away. Another point in favor of values is that
https://github.com/andyferris/Dictionaries.jl
took that stance, and that seems a really well-thought-out library.

I'm not strongly against pairs iteration and map though, just seems less convenient to me but I'm sure it would work just fine.
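
To make the trade-off concrete, here is a small sketch that uses a plain IdDict in place of Grads (an assumption, so it runs without Zygote). The `lr_scale` rule is hypothetical; it only shows that a key-aware operation stays available through `pairs` while value-only operations never mention the key:

```julia
w = [1.0, 2.0, 3.0]
b = [0.0, 0.0]

# Stand-in for Grads: parameters as keys, gradients as values.
grads = IdDict{Any,Any}(w => [1.5, -2.0, 0.3], b => [0.5, 0.5])

# Value-based use case: clamping only needs the gradient arrays.
clamped = [clamp.(g, -1.0, 1.0) for g in values(grads)]

# Key-aware use case via pairs (hypothetical rule): shrink only the
# gradient that belongs to the parameter `b`, identified by identity.
lr_scale(p) = p === b ? 0.01 : 1.0
scaled = IdDict{Any,Any}(p => lr_scale(p) .* g for (p, g) in pairs(grads))
```

Note that the key here is the parameter array itself, so any per-layer rule would have to be expressed in terms of parameter identity.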

Maybe @andyferris could give some advice here

@oxinabox
Member

Ok, I am sold, I trust you have thought this through.

@darsnack
Member

I feel like it would be surprising if map and broadcast did different things (though there is no rule that they should be the same). In this case, I believe the proposed map(f, pairs(grads)) is the clearest version of the intent of the code. map(f, grads) operating on pairs would become confusing to me when we go to multiple grads: map(f, gs1, gs2, ...). You could make sure the key sets are equal and only have one key at the start, but that's too confusing imo.

The main issue for me though is that the ordering of the pairs is not guaranteed across Grads. As a user, if two Grads have the same parameters, then I expect the values associated with parameter A to be matched regardless of storage order. This is hard to guarantee in a way that's communicated to the user. For example, for gs1 .+ gs2, we could just pick the order of the first argument. But for gs1 .+ gs2 .- [rand(size(p)) for p in params(m)], how does one know what order the generator should be in? Broadcasting on values allows us to play nicer with the rest of the broadcasted expression, but I think you lose the ordering guarantee which makes things murky.
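
A sketch of the key-matching behavior described here, under the assumption that gradients live in dictionary-like containers keyed by parameter. The function name `combine` is hypothetical and this is not the PR's implementation; it only illustrates matching by key after verifying that the key sets agree:

```julia
# Hypothetical helper: combine two keyed gradient containers by key,
# independent of the order in which the entries were stored.
function combine(f, d1::AbstractDict, d2::AbstractDict)
    issetequal(keys(d1), keys(d2)) || throw(ArgumentError("key sets differ"))
    out = IdDict{Any,Any}()
    for (p, g) in d1
        out[p] = f(g, d2[p])   # look up the partner value by key
    end
    return out
end

w = [1.0, 2.0]
b = [3.0]

# Same parameters, deliberately inserted in a different order.
gs1 = IdDict{Any,Any}(w => [0.1, 0.2], b => [0.3])
gs2 = IdDict{Any,Any}(b => [1.0], w => [10.0, 20.0])

total = combine(+, gs1, gs2)
```

A positional argument such as a plain vector of arrays has no keys to match on, which is where the ordering question in the comment above comes from.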

One thing I will suggest is that maybe we extend the current binary operation broadcast between a number and a Grads to be any binary function? Would be surprising that I can do 2 .* grads but not min.(grads, 1e-4) for example.

@darsnack
Member

One other option is to allow broadcasting not just with other Grads but any Pair. We can just check for key set equality like we do. This would allow for broadcasting with regular arrays in a sensible way. But we can save this for another PR if we want.

@CarloLucibello
Member Author

CarloLucibello commented Feb 20, 2021

Would be surprising that I can do 2 .* grads but not min.(grads, 1e-4) for example.

This is a bit tricky, because min would need two levels of broadcasting to work, kind of min... While 2 * grads[p] is well defined, min(grads[p], 1e-4) is not.

One other option is to allow broadcasting not just with other Grads but any Pair. We can just check for key set equality like we do. This would allow for broadcasting with regular arrays in a sensible way. But we can save this for another PR if we want.

handling Pairs seems fine, I'll try to implement it

Edit. Actually, we can extend support to AbstractDict

@darsnack
Member

This is a bit tricky, because min would need two levels of broadcasting to work, kind of min... While 2 * grads[p] is well defined, min(grads[p], 1e-4) is not.

My bad, that was a poor example. I just meant that if f(x::Number, y::Number) is defined, it would be surprising if 2 .* grads worked but f.(grads, 2) did not. I think if you just change the binary op def to:

broadcasted(f, a::Number, gs::Grads) = map(x -> f(a, x), gs)
broadcasted(f, gs::Grads, a::Number) = map(x -> f(x, a), gs)

it would work. Probably people only care about * but this is easy enough to add safely.
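
For readers who want to try the two definitions above in isolation, here is a self-contained sketch on a toy container. `ToyGrads` is a hypothetical stand-in, not `Zygote.Grads`; the sketch assumes only that dot syntax lowers to `broadcasted` followed by `materialize`, and that `materialize` returns a non-lazy result unchanged:

```julia
import Base.Broadcast: broadcasted

# Toy stand-in for a gradient container (NOT Zygote.Grads).
struct ToyGrads
    grads::IdDict{Any,Any}
end

Base.getindex(gs::ToyGrads, p) = gs.grads[p]

# Value-wise map that returns a new container with the same keys.
function Base.map(f, gs::ToyGrads)
    out = IdDict{Any,Any}()
    for (p, g) in gs.grads
        out[p] = f(g)
    end
    return ToyGrads(out)
end

# The two definitions from the comment, applied to the toy type.
broadcasted(f, a::Number, gs::ToyGrads) = map(x -> f(a, x), gs)
broadcasted(f, gs::ToyGrads, a::Number) = map(x -> f(x, a), gs)

w = [1.0, 2.0]
b = [4.0]
gs = ToyGrads(IdDict{Any,Any}(w => [0.5, -1.0], b => [2.0]))

doubled = 2 .* gs   # dispatches to broadcasted(*, 2, gs)
halved = gs ./ 2    # dispatches to broadcasted(/, gs, 2)
```

Because `f` is applied to each whole gradient array, this only works for functions defined between a number and an array, such as `*` and `/`; that is the same limitation raised for `min` earlier in the thread.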

@CarloLucibello
Member Author

broadcasted(f, a::Number, gs::Grads) = map(x -> f(a, x), gs)
broadcasted(f, gs::Grads, a::Number) = map(x -> f(x, a), gs)

@darsnack should we generalize this to Union{Number, AbstractArray}?

@@ -185,7 +188,7 @@ function copy!(x::AbstractVector, gs::Grads)
   x
 end

-broadcasted(f, gss::Grads...) = map(f, gss...)
+broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
Member

Would want this

Suggested change
-broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
+broadcasted(f, gs::ADictOrGrads, gss::ADictOrGrads...) = map(f, gs, gss...)

But that would overload broadcasting for the case where gs and gss are just Dicts (no Grads). I don't know a quick way around it, but it is unfortunate that the first arg must be a Grads.

Member Author

yep, I don't know either, unless we start adding a bunch of definitions like

broadcasted(f, gs1::AbstractDict, gs2::Grads, gss::ADictOrGrads...) = map(f, gs1, gs2, gss...)

and the corresponding ones for map, but this seems pretty annoying and I would leave it to future PRs

Member

@darsnack darsnack Feb 20, 2021

Is broadcasted(f, gss::Vararg{Union{AbstractDict, Grads}}) different from broadcasted(f, gss::Vararg{<:Union{AbstractDict, Grads}})? I think the first forces the union type. Let me double check.

It does not 😞. Yeah I think leave this for another PR unless someone knows a quick fix.

@darsnack
Member

should we generalize this to Union{Number, AbstractArray}?

We could. It seems unlikely that every parameter in the Grads has the same shape, but I think semantically this is safe to do. I think now that we have the ability to support different shapes through broadcasting with AbstractDict, I don't have any gripes. We just want to make clear in the docs that if you need to broadcast with a vector of arrays (cause the shapes are different), you need to use [get_array(p) for p in keys(grads)].

@CarloLucibello
Member Author

you need to use [get_array(p) for p in keys(grads)].

actually, this case is problematic, since if a = [get_array(p) for p in keys(grads)] then gs .+ a would imply gs[p] + a for every p, which is wrong. Given this scenario, we had better disallow broadcast with arrays.
That usage is achieved with Dict(p => get_array(p) for p in keys(grads)).
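
A small sketch of the keyed alternative, with an IdDict standing in for both Grads and the Dict so that parameters are matched by identity (an assumption made for this illustration). `get_array` is only named in the discussion, so the definition below is a hypothetical placeholder:

```julia
# Hypothetical placeholder for the `get_array` mentioned above:
# returns an array with the same shape as the parameter.
get_array(p) = fill(0.5, size(p))

w = zeros(2, 3)
b = zeros(3)

# Stand-in for Grads with differently shaped gradients.
grads = IdDict{Any,Any}(w => ones(2, 3), b => ones(3))

# Keyed by parameter, so each entry has the shape of its own gradient.
offsets = IdDict{Any,Any}(p => get_array(p) for p in keys(grads))

# Per-parameter combination: grads[p] is paired with offsets[p] only.
shifted = IdDict{Any,Any}(p => grads[p] .+ offsets[p] for p in keys(grads))
```

Keying by parameter is what lets entries of different shapes coexist, which a single array broadcast against every gradient cannot do.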

@darsnack
Member

darsnack commented Feb 20, 2021

I totally agree and just forgot to write the p => in my comment.

@CarloLucibello
Member Author

In the end, I allowed broadcasted binary operations involving numerical types.

I think we are done. One little regret is that we now support

gs .+ Dict(w => rand(2), b => nothing)

but it would be nice to relax the issetequal assumption on the keys and allow

gs .+ Dict(w => rand(2))

Doesn't seem easy to do with current map implementation ... material for future PRs

Member

@darsnack darsnack left a comment

Looks really great! Thanks for the effort.

CarloLucibello and others added 3 commits February 20, 2021 20:12
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@CarloLucibello
Member Author

bors r+

@bors
Contributor

bors bot commented Feb 20, 2021

Build succeeded:

@bors bors bot merged commit 6b89a06 into master Feb 20, 2021
@bors bors bot deleted the cl/algebra branch February 20, 2021 20:56
@andyferris

andyferris commented Feb 21, 2021

I’m not sure I can comment specifically on Grad but these are really interesting discussions.

I particularly agree with:

In this case, I believe the proposed map(f, pairs(grads)) is the clearest version of the intent of the code. map(f, grads) operating on pairs would become confusing to me when we go to multiple grads: map(f, gs1, gs2, ...).

This is a really good point. The other nice thing under default value iteration is that map(f, keys(d), d, d2, ...) works well when you want the key, while on the other hand f is a bit tortured under map(f, pairs(d1), pairs(d2), ...).

As to the differences between map and broadcast: in my personal interpretation, map is about matching up on iteration and broadcast is about (more flexibly) matching up on indices. I think it would be confusing if the elements they got differed, though.

Successfully merging this pull request may close these issues.

accumulate gradient with the new gradient API?
6 participants