@danbraunai
Contributor

@danbraunai danbraunai commented Jan 28, 2025

Description

HookPoint.add_hook() contained the line:

full_hook.__name__ = (
    hook.__repr__()
)  # annotate the `full_hook` with the string representation of the `hook` function

This PR instead proposes

if isinstance(hook, partial):
    full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
else:
    full_hook.__name__ = hook.__repr__()

This isn't an issue when the hook is a regular function. However, the original code can be extremely expensive when hook is a partial function whose arguments are large, because calling partial.__repr__ calls __repr__ on every one of its bound arguments. A very natural and common usage of hooks is to create a partial function that takes a cache object as an argument. Some of the demos in the repo do this, e.g. Exploratory_Analysis_Demo.ipynb has:

hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
patched_logits = model.run_with_hooks(
    corrupted_tokens,
    fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
    return_type="logits",
)

where cache is the activation cache returned by model.run_with_cache() on gpt2-small.
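A minimal sketch of the overhead, using only the standard library. `BigCache` and `patch_hook` are hypothetical stand-ins (not part of TransformerLens) for a large activation cache and a hook function. The counter shows that `repr()` on a partial formats every bound argument, while the proposed name never touches them:

```python
from functools import partial


class BigCache:
    """Hypothetical stand-in for a large activation cache."""

    def __init__(self):
        self.repr_calls = 0

    def __repr__(self):
        # A real cache would format many tensors here, which is the slow part.
        self.repr_calls += 1
        return "BigCache(...)"


def patch_hook(activation, hook=None, clean_cache=None):
    """Hypothetical hook; returns the activation unchanged."""
    return activation


cache = BigCache()
hook_fn = partial(patch_hook, clean_cache=cache)

# Original behaviour: repr of the partial also calls repr on each bound argument.
old_name = repr(hook_fn)
calls_after_old = cache.repr_calls

# Proposed behaviour: only the wrapped function is formatted.
new_name = f"partial({hook_fn.func.__repr__()},...)"
calls_after_new = cache.repr_calls
```

With a real cache the first call has to build string representations of every stored tensor on each `add_hook()`, which is where the cost comes from.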

I think it's very unlikely that anyone relies on the specific name of the PyTorch-native hook that this change affects. It would only break if a user did some kind of string search for one of the specific arguments of their partial function inside the raw nn.Module._forward_hooks.

I have not added tests, as the use case for accessing this name attribute when using a partial function is hyper-specific. I can add some if I'm missing a reasonable use case.

This change makes my code 3 times faster (I don't have a public version of it yet, but it does some model training and uses run_with_hooks twice in the forward pass).

Fixes #631

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@ArthurConmy
Collaborator

LGTM but, thoughts on this? For this CL, or a future CL.

Idea: show arg and kwarg names, separated by * if both are present.

# Annotate the `full_hook` with the string representation of the `hook` function:
if isinstance(hook, partial):
    # Get function name safely
    func_name = getattr(hook.func, '__name__', hook.func.__class__.__name__)
    
    # Always show args as keyword arguments by getting their names from the signature
    sig = inspect.signature(hook.func)
    param_names = list(sig.parameters.keys())
    
    # Format args as keyword assignments using parameter names
    args_kwstr = ', '.join(f'{param_names[i]}=...' for i in range(len(hook.args)))
    
    # Format normal kwargs
    kwargs_str = ', '.join(f'{k}=...' for k in hook.keywords.keys())
    
    # Combine with * separator if both present
    params_str = args_kwstr
    if args_kwstr and kwargs_str:
        params_str += f', *, {kwargs_str}'
    elif kwargs_str:
        params_str = kwargs_str
    full_hook.__name__ = f"partial({func_name}({params_str}))"
else:
    full_hook.__name__ = hook.__repr__()
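The same idea can be sketched as a standalone helper, which makes it easier to test outside `add_hook()`. The name `describe_hook` is hypothetical and not part of TransformerLens. This version also assumes two guards that the snippet above does not have: it tolerates callables whose signature cannot be inspected, and positional arguments beyond the named parameters (e.g. for `*args` functions) are shown as a bare `...`:

```python
import inspect
from functools import partial


def describe_hook(hook):
    """Return a cheap name for `hook` without formatting bound argument values."""
    if not isinstance(hook, partial):
        return repr(hook)

    # Fall back to the class name for callables without a __name__.
    func_name = getattr(hook.func, "__name__", type(hook.func).__name__)

    # Some builtins and C callables have no inspectable signature.
    try:
        param_names = list(inspect.signature(hook.func).parameters)
    except (TypeError, ValueError):
        param_names = []

    # Show positional args under their parameter name where one exists.
    args_str = ", ".join(
        f"{param_names[i]}=..." if i < len(param_names) else "..."
        for i in range(len(hook.args))
    )
    kwargs_str = ", ".join(f"{k}=..." for k in hook.keywords)

    # Join with a `*` separator only when both groups are present.
    params_str = ", *, ".join(s for s in (args_str, kwargs_str) if s)
    return f"partial({func_name}({params_str}))"


def example_hook(activation, hook, pos=None, clean_cache=None):
    """Hypothetical hook used only to exercise describe_hook."""
    return activation


name_both = describe_hook(partial(example_hook, 1, clean_cache=object()))
name_kwargs = describe_hook(partial(example_hook, pos=3))
name_plain = describe_hook(example_hook)
```

Only argument names are read, so the cost stays independent of how large the bound values are.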

@ArthurConmy ArthurConmy self-requested a review January 28, 2025 22:18
@danbraunai
Contributor Author

Yeah, good idea, but I think getting the args might be overkill here: if the user knows it's a partial and gets the name of the function that the partial wraps, they will know what the other argument names are just by the definition of the function. I don't think that can vary.

I'm also just unsure if anyone has ever accessed the __name__ of the underlying pytorch hook, but maybe I'm missing a use case there.

@bryce13950
Collaborator

Thanks for getting to this! I am putting up a release today, and I will get this in there.

@bryce13950 bryce13950 changed the base branch from main to dev February 3, 2025 18:44
@bryce13950 bryce13950 merged commit 1a6cb6a into TransformerLensOrg:dev Feb 3, 2025
13 checks passed
Development

Successfully merging this pull request may close these issues.

[Proposal] Remove the overhead caused by full_hook.__name__ = (hook.__repr__())?
