Handle tied weights in update! #42

Closed
wants to merge 11 commits into from

mcabbott
Member

@mcabbott mcabbott commented Jan 30, 2022

This alters setup to record the "address" of any tied weights, and then update! to first add the gradient of the second to the first (of each pair), then update as normal, and finally re-create the tie in the updated model.
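
As a rough illustration of those three steps, here is a minimal sketch on a flat tuple of parameters. It assumes plain gradient descent and integer positions as addresses; `tied_descent` and its arguments are hypothetical names for this example, not the code in this PR.

```julia
# Hypothetical, simplified illustration of the three steps; not the PR's code.
# `ties` holds pairs `duplicate => original`, given as positions in the tuple.
function tied_descent(params::Tuple, grads::Tuple, ties; eta = 0.1)
    g = collect(Any, grads)
    for (dup, orig) in ties
        g[orig] = g[orig] .+ g[dup]      # 1. add the second gradient to the first
    end
    new = Any[p .- eta .* gi for (p, gi) in zip(params, g)]  # 2. update as normal
    for (dup, orig) in ties
        new[dup] = new[orig]             # 3. re-create the tie in the updated model
    end
    return Tuple(new)
end

updated = tied_descent(([1.0, 2.0], [1.0, 2.0]), ([5.0, 6.0], [7.0, 8.0]), [2 => 1]; eta = -0.01)
```

After the call both entries of `updated` are the same array, so the tie survives the update.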

I've tried to match the existing update! as much as possible. The format in which the "address" is stored is just a tuple of property names. The function to "pick out" a gradient component based on this is easy. The one to "place back" the modified one is trickier, as it needs to re-create missing branches -- and not just the minimal branch, but all the other empty fields. So it walks the model in parallel to the gradient, a lot like the existing update!. This isn't something that e.g. Setfield.jl thinks about.
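
To make the "pick out" and "place back" idea concrete, here is a small sketch on nested NamedTuples. The names `pick` and `place` follow the discussion, but the bodies below are assumptions written for illustration and may differ from the PR's implementation.

```julia
# Hypothetical sketch on nested NamedTuples; addresses are tuples of property names.

# pick: follow the address into the gradient; a missing branch is `nothing`.
pick(dx, addr::Tuple) = isempty(addr) ? dx : pick(getproperty(dx, first(addr)), Base.tail(addr))
pick(::Nothing, addr::Tuple) = nothing

# place: return a copy of gradient `dx` with `val` stored at `addr`. The model `x`
# is walked in parallel, so a missing branch is rebuilt with all of its fields.
function place(x, dx, val, addr::Tuple)
    isempty(addr) && return val
    names = propertynames(x)
    vals = map(names) do k
        dk = dx === nothing ? nothing : getproperty(dx, k)
        k === first(addr) ? place(getproperty(x, k), dk, val, Base.tail(addr)) : dk
    end
    return NamedTuple{names}(vals)
end

model = (enc = (w = [1.0], b = [0.0]), dec = (w = [1.0],))
grad = (enc = nothing, dec = (w = [2.0],))   # the `enc` branch is missing entirely

picked = pick(grad, (:dec, :w))
placed = place(model, grad, [3.0], (:enc, :w))
```

Here `placed.enc` is rebuilt as a full NamedTuple with `b = nothing`, which is the "all the other empty fields" part described above.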

Maybe this could all be abstracted away somehow, moved up into Functors? We seem to need quite a few patterns which aren't like fmap.

Surely the "address" could be stored in a more compile-away-able format, à la Setfield.jl. None of this is type stable, and it takes a few μs. I think that's true of everything in Functors.jl too. And it might even be desirable, for startup speed with deeply nested models.

This is on top of #41. Uses FluxML/Functors.jl#33, which is copied in here for now.

@mcabbott mcabbott force-pushed the duplicated branch 3 times, most recently from c19e259 to d9f4d80 Compare January 30, 2022 01:53
@ToucheSir
Member

Following from our discussion on design, how does composition of states work with this approach? Since ties, pick and place work on absolute addresses from the tree root, it's not clear to me how to keep things in sync when optimizing a subset (e.g. gradual unfreezing) or a superset (e.g. adding a classifier head and fine-tuning) of the original state.

@mcabbott
Member Author

Good question. These are all from the root. You can't accidentally take a branch of the state & model and lose them, since the outermost Tree type will mean the types don't line up. If instead you pasted these, whole, into some larger model's pair, I bet they would just work? Unless doing that created further ties which neither half saw before.

Explicitly messing with the state tree with your bare hands seems obviously at-own-risk. Better tools for explicitly freezing etc. would, I guess, need to interact with this, perhaps by re-running the check for ties which setup does here?

src/interface.jl Outdated
Comment on lines 15 to 26
function setup(rule, x; ties = Pair[], cache = IdDict())
    tree = _setup(rule, x, (); ties, cache)
    isempty(ties) ? tree : Tree(ties, tree)
end
Member Author

Mostly accidentally, this method in fact lets you construct ties where they aren't inferred:

julia> using StaticArrays

julia> st = Optimisers.setup(Descent(-0.01), (SA[1,2], SA[3,4]), ties = [(2,) => (1,)])
Optimisers.Tree([(2,) => (1,)], (Leaf(Descent{Float64}(-0.01), nothing), Leaf(Descent{Float64}(-0.01), nothing)))

julia> Optimisers.update!(st, (SA[1,2], SA[3,4]), (SA[5,6], SA[7,8]))
(Optimisers.Tree([(2,) => (1,)], (Leaf(Descent{Float64}(-0.01), nothing), Leaf(Descent{Float64}(-0.01), nothing))), ([1.12, 2.14], [1.12, 2.14]))

@darsnack
Member

darsnack commented Jan 31, 2022

Had the time to go through this and FluxML/Functors.jl#33. I'm happy with both PRs. I had one bike shed comment, but after reading the discussion, I think it's relevant too.

Should Tied be something broader, like InitializedTree? Something to indicate that this is a tree + auxiliary information? In the future we might consider adding different kinds of auxiliary information. Going over the two cases above:

  • Doing a superset like backbone + classifier is probably simple enough to just splat the addresses correctly.
  • More complicated is gradual (un)freezing where branches come in and out. In this case, I think it's easier to keep the trees complete and include auxiliary information that says "this branch is frozen, don't recurse or update." This also opens this piece of the design up so that other packages can write iterative freezing algorithms and all they have to do is use utilities like fsimilar, etc.

@ToucheSir
Member

I have to admit the implementation is a tad brain-warping, though all that complexity is necessary so there's not much more to say. I agree with Kyle that we should give this a go. If some unforeseen case rears its ugly head then hey, Optimisers.jl isn't anywhere near stable yet and the tied logic is sufficiently decoupled from the core Optimisers functionality.

@mcabbott
Member Author

mcabbott commented Feb 1, 2022

Re interaction with other things:

  • Storing this information in some struct at the base of the state tree alters its shape, and means that anything else walking this tree needs to know about this storage struct, the one which doesn't have a parallel in the model. Perhaps that's an argument against this way of doing it. But storing tie information (say) at the first leaf of each pair instead, seems tricky too --- would the address stored there still be relative to the root? That seems more dangerous, as grafting branches would make it wrong.

  • To freeze some weights there's a similar question of where to store the information. It could also be stored at the base (either expanding Tied to have two jobs, or by composing two such things, each with their own update! dispatch). Such base storage would need updating on any graft / backbone + classifier operation. It could also be stored at the leaf, e.g. by adding a "freeze" flag to Leaf. Or it could be stored by inserting a struct at the base of the sub-tree which is entirely frozen. My guess is that per-leaf sounds the simplest in fact. Unlike ties you aren't forced to think about the "address" of particular leaves.

Edit: #49 is a sketch of how freezing at the leaf would look.
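
For concreteness, a per-leaf "freeze" flag could look roughly like the sketch below. `SketchLeaf`, `apply_rule` and `update_sketch` are hypothetical names, plain gradient descent stands in for a real rule, and none of this is taken from #49.

```julia
# Hypothetical types for illustration; not the structs from Optimisers.jl or #49.
struct SketchLeaf{S}
    state::S
    frozen::Bool
end

# Plain gradient descent stands in for a real rule; a frozen leaf returns x unchanged.
apply_rule(leaf::SketchLeaf, x, dx; eta = 0.1) = leaf.frozen ? x : x .- eta .* dx

# Walk state, model and gradient in parallel. No addresses are needed,
# because each leaf carries its own flag.
update_sketch(leaf::SketchLeaf, x, dx) = apply_rule(leaf, x, dx)
update_sketch(tree::NamedTuple, x, dx) = map(update_sketch, tree, x, dx)

state = (backbone = SketchLeaf(nothing, true), head = SketchLeaf(nothing, false))
model = (backbone = [1.0, 2.0], head = [3.0])
grad = (backbone = [1.0, 1.0], head = [1.0])

stepped = update_sketch(state, model, grad)
```

Only `head` moves here; `backbone` is passed through as the same array.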

@ToucheSir
Member

Anything you want to add here before we consider letting users have at it?

@mcabbott
Member Author

mcabbott commented Feb 9, 2022

src/interface.jl Outdated
Comment on lines 15 to 21
function setup(rule, x; ties = Pair[], cache = IdDict())
    tree = _setup(rule, x, (); ties, cache)
    isempty(ties) ? tree : Tied(ties, tree)
end

function _setup(rule, x, addr; ties, cache)
    usecache = !isbits(x) && cache !== false
Member Author

At present this does not use the cache for isbits things, and allows cache=false to disable it completely. Ideally this would probably be harmonised with Functors.
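
The effect of that `usecache` line can be seen in a stripped-down walk that only records ties. This is a sketch under assumptions: `find_ties` is a hypothetical helper that reuses the snippet's keyword names, and it is not the PR's `_setup`.

```julia
# Hypothetical walk, loosely modelled on the snippet above: remember the address
# of each array leaf, and record a tie when the same object is seen again.
function find_ties(x, addr = (); ties = Pair[], cache = IdDict())
    if x isa AbstractArray
        usecache = !isbits(x) && cache !== false
        if usecache && haskey(cache, x)
            push!(ties, addr => cache[x])   # second address => first address
        elseif usecache
            cache[x] = addr
        end
    elseif x isa Union{Tuple, NamedTuple}
        for k in keys(x)
            find_ties(x[k], (addr..., k); ties, cache)
        end
    end
    return ties
end

w = [1.0, 2.0]
found = find_ties((a = w, b = (c = w,)))   # the same array appears twice
bits = find_ties((1:3, 1:3))               # ranges are isbits, so never cached
off = find_ties((w, w); cache = false)     # cache disabled completely
```

Only the first call reports a tie; the isbits case and the `cache = false` case both come back empty.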

@ToucheSir
Member

A thought I had today was whether we couldn't obviate the need for Tied by making Leaf a mutable struct. Then tying two or more params could be represented by sharing the same leaf instance in the tree. This would not solve the composition problem, but it could reduce quite a bit of internal complexity and would work well with #49.

@mcabbott
Member Author

> A thought I had today was whether we couldn't obviate the need for Tied by making Leaf a mutable struct. Then tying two or more params could be represented by sharing the same leaf instance in the tree.

I'm not sure I see how this would work. Certainly setup could produce such a thing, with a cache à la Functors. But on the walk for update!, you need to know when you get to the first of a tied pair, so that you accumulate the two gradients and apply the rule once. To do this, update! would I think have to build something much like Tied as a first pass, before proceeding much like this.

@ToucheSir
Member

The idea is that update! on a non-leaf runs two passes: first accumulate gradients into a cache, and then do the usual traversal with a visited set of leaves to avoid running rules twice. I tried writing some pseudocode for this, but it quickly metastasized into a full-blown branch. Will try to toss up a PR showing off the approach this week.
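
One possible reading of that two-pass idea, as a minimal sketch: it assumes a mutable leaf whose identity marks the tie, tuples as the only containers, and plain gradient descent as the rule. `SharedLeaf`, `sum_grads!`, `run_rules!` and `two_pass_update` are hypothetical names and do not come from any branch mentioned here.

```julia
# Hypothetical mutable leaf: sharing one instance marks parameters as tied.
# It has to be mutable so that two equal-looking leaves are still distinct objects.
mutable struct SharedLeaf
    eta::Float64
end

# Pass 1: sum gradients per leaf instance, keyed by object identity.
function sum_grads!(acc::IdDict, leaf::SharedLeaf, dx)
    acc[leaf] = haskey(acc, leaf) ? acc[leaf] .+ dx : dx
    return acc
end
function sum_grads!(acc::IdDict, tree::Tuple, dx::Tuple)
    foreach((t, d) -> sum_grads!(acc, t, d), tree, dx)
    return acc
end

# Pass 2: the usual walk; `done` is the visited set, so each rule runs only once
# and later visits reuse the stored result.
function run_rules!(done::IdDict, acc::IdDict, leaf::SharedLeaf, x)
    haskey(done, leaf) || (done[leaf] = x .- leaf.eta .* acc[leaf])
    return done[leaf]
end
run_rules!(done::IdDict, acc::IdDict, tree::Tuple, x::Tuple) =
    map((t, xi) -> run_rules!(done, acc, t, xi), tree, x)

function two_pass_update(tree, x, dx)
    acc = sum_grads!(IdDict(), tree, dx)
    return run_rules!(IdDict(), acc, tree, x)
end

leaf = SharedLeaf(-0.01)
result = two_pass_update((leaf, leaf), ([1.0, 2.0], [1.0, 2.0]), ([5.0, 6.0], [7.0, 8.0]))
```

With the same inputs as the StaticArrays example earlier in the thread, both entries of `result` end up as one shared array.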

Labels
enhancement New feature or request
4 participants