Manually create repr for partial hooks #845
Merged
Description
`HookPoint.add_hook()` contained a line that derived the wrapped hook's name from the hook's `__repr__`. This PR instead proposes building that name manually when the hook is a partial function.

This isn't an issue when the hook is a regular function, but the original code can be extremely expensive when the hook is a partial function whose arguments are large, since calling `partial.__repr__` calls `__repr__` on all of its arguments. A very natural and common usage of hooks is to create a partial function that takes a cache object as an argument. Some of the demos in the repo do this, e.g. `Exploratory_Analysis_Demo.ipynb` has a partial hook in which `cache` comes from a gpt2-small `model.run_with_cache()` call.

I'll note that I think it's very unlikely that anyone would rely on the specific name of a PyTorch-native hook that this affects. This would only break if a user did some kind of string search for one of the specific arguments of their partial function inside the raw `nn.Module._forward_hooks`.

I have not added tests, as the use case for accessing this name attribute when using a partial function is hyper-specific. But I can add them if I'm missing a reasonable use case.
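The cost described above can be illustrated with a minimal, self-contained sketch. It is not the code from this PR: `BigCache`, `patch_hook`, and `hook_name` are hypothetical names, and `BigCache` only stands in for a large activation cache. The sketch shows that `repr()` on a partial also calls `__repr__` on its bound arguments, while a manually built name can skip them.

```python
from functools import partial


class BigCache:
    """Hypothetical stand-in for a large activation cache."""

    def __init__(self):
        self.repr_calls = 0

    def __repr__(self):
        # Count calls so we can see when the expensive repr is triggered.
        self.repr_calls += 1
        return "BigCache(" + ", ".join(str(i) for i in range(10_000)) + ")"


def patch_hook(activation, hook, cache):
    """Hypothetical hook that would read from the cache."""
    return activation


def hook_name(hook):
    """Build a cheap display name for a hook (illustrative helper, an assumption)."""
    if isinstance(hook, partial):
        # Only repr the wrapped function; skip the bound arguments entirely.
        return f"partial({hook.func!r},...)"
    return repr(hook)


cache = BigCache()
hook = partial(patch_hook, cache=cache)

full_name = repr(hook)        # partial.__repr__ also reprs `cache`
cheap_name = hook_name(hook)  # never touches `cache`

print(cache.repr_calls)
print(len(full_name), len(cheap_name))
```

The trade-off is that the cheap name no longer shows the bound arguments, which is the only behaviour a string search over hook names could notice.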
This change makes my code 3 times faster (I don't have a public version of it yet, but it does some model training and uses run_with_hooks twice in the forward pass).
Fixes #631