Handle tied weights in update!
#42
Conversation
Force-pushed from c19e259 to d9f4d80.
Following from our discussion on design, how does composition of states work with this approach? Since …
Good question. These are all from the root. You can't accidentally take a branch of the state & model and lose them, since the outermost Tree type will mean the types don't line up. If instead you pasted these, whole, into some larger model's pair, I bet they would just work? Unless doing that created further ties which neither half saw before. Explicitly messing with the state tree with your bare hands seems obviously at-own-risk. Better tools for explicitly freezing etc. … I guess these would need to interact with this, re-run the check for ties which …
src/interface.jl
Outdated
```julia
function setup(rule, x; ties = Pair[], cache = IdDict())
    tree = _setup(rule, x, (); ties, cache)
    isempty(ties) ? tree : Tree(ties, tree)
end
```
Mostly accidentally, this method in fact lets you construct ties where they aren't inferred:
```julia
julia> using StaticArrays

julia> st = Optimisers.setup(Descent(-0.01), (SA[1,2], SA[3,4]), ties = [(2,) => (1,)])
Optimisers.Tree([(2,) => (1,)], (Leaf(Descent{Float64}(-0.01), nothing), Leaf(Descent{Float64}(-0.01), nothing)))

julia> Optimisers.update!(st, (SA[1,2], SA[3,4]), (SA[5,6], SA[7,8]))
(Optimisers.Tree([(2,) => (1,)], (Leaf(Descent{Float64}(-0.01), nothing), Leaf(Descent{Float64}(-0.01), nothing))), ([1.12, 2.14], [1.12, 2.14]))
```
Had the time to go through this and FluxML/Functors.jl#33. I'm happy with both PRs. I had one bike-shed comment, but after reading the discussion, I think it's relevant too. Should …
I have to admit the implementation is a tad brain-warping, though all that complexity is necessary, so there's not much more to say. I agree with Kyle that we should give this a go. If some unforeseen case rears its ugly head then hey, Optimisers.jl isn't anywhere near stable yet, and the tied logic is sufficiently decoupled from the core Optimisers functionality.
Re interaction with other things:
Edit: #49 is a sketch of how freezing at the leaf would look.
Anything you want to add here before we consider letting users have at it?
src/interface.jl
Outdated
```julia
function setup(rule, x; ties = Pair[], cache = IdDict())
    tree = _setup(rule, x, (); ties, cache)
    isempty(ties) ? tree : Tied(ties, tree)
end
```

```julia
function _setup(rule, x, addr; ties, cache)
    usecache = !isbits(x) && cache !== false
    # ...
```
At present this does not use the cache for `isbits` things, and allows `cache = false` to disable it completely. Ideally this would probably be harmonised with Functors.
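A standalone sketch (illustrative only, not code from this PR) of why the sharing cache skips `isbits` values: equal immutable bits values are `===`, so an `IdDict` cannot distinguish a genuinely shared object from a coincidentally equal one, whereas mutable arrays carry real identity.

```julia
# Illustrative only -- not Optimisers.jl source. An IdDict-based sharing
# check is only meaningful for non-isbits (mutable) objects.
x = (1.0, 2.0)
y = (1.0, 2.0)
@assert x === y            # equal isbits values are egal: no usable identity

v = [1.0, 2.0]
w = [1.0, 2.0]
@assert v !== w            # equal arrays are still distinct objects

cache = IdDict{Any,Nothing}()
cache[v] = nothing
@assert haskey(cache, v)   # the array itself is found...
@assert !haskey(cache, w)  # ...but an equal copy is not: genuine sharing detection
```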
A thought I had today was whether we couldn't obviate the need for …
I'm not sure I see how this would work. Certainly …
The idea is that …
This alters `setup` to record the "address" of any tied weights, and then `update!` to first add the gradient of the second to the first (of each pair), then update as normal, and finally re-create the tie in the updated model. I've tried to match the existing `update!` as much as possible.

The format in which the "address" is stored is just a tuple of property names. The function to "pick out" a gradient component based on this is easy. The one to "place back" the modified one is trickier, as it needs to re-create missing branches -- and not just the minimal branch, but all the other empty fields. So it walks the model in parallel to the gradient, a lot like the existing `update!`. This isn't something that e.g. Setfield.jl thinks about. Maybe this could all be abstracted away somehow, moved up into Functors? We seem to need quite a few patterns which aren't like `fmap`.

Surely the "address" could be stored in a more compile-away-able format, à la Setfield.jl. None of this is type stable, and it takes a few μs. I think that's true of everything in Functors.jl too. And it might even be desirable, for startup speed with deeply nested models.

This is on top of #41. Uses FluxML/Functors.jl#33, which is copied in here for now.
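A rough sketch of the easy "pick out" direction described above (the helper name `pick` and the model shape are mine for illustration, not this PR's internals): an address tuple of property names, or of tuple indices like `(2,)`, is just folded through `getfield`.

```julia
# Hypothetical helper, for illustration only: walk a nested model by an
# "address" tuple of property names (or tuple positions).
pick(x, addr::Tuple) = foldl((y, k) -> getfield(y, k), addr; init = x)

model = (layers = (dense = (W = [1 2; 3 4], b = [0, 0]),),)
@assert pick(model, (:layers, :dense, :W)) == [1 2; 3 4]
@assert pick(([1, 2], [3, 4]), (2,)) == [3, 4]   # same shape as the (2,) => (1,) addresses
```

The hard direction, "place back", has no such one-liner, which is exactly the point made above about re-creating missing branches.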