Equivalent of register_forward_hook #218

Open
cpcdoy opened this issue Jun 20, 2020 · 5 comments

Comments

@cpcdoy

cpcdoy commented Jun 20, 2020

Hi,

I've been using hooks and especially register_forward_hook in PyTorch and wanted to be able to do the same in Rust.

Unless I missed something (I checked both torch-sys and tch-rs for a similar function), there doesn't seem to be one. Is there any way to emulate a hook using your API?

Thanks a lot

@LaurentMazare
Owner

I don't think there is an equivalent for this at the moment: the current API is mostly generated automatically from the declarations.yaml file, and I haven't found anything related in it.
Do you know if this could be done with the C++ API? Also, an example of your typical use case might help us see whether there is a way to achieve the same kind of thing in tch-rs.

@cpcdoy
Author

cpcdoy commented Jun 21, 2020

I haven't used the C++ API before, so I'm not quite sure how to replicate it there.

Here's how I've been using it in Python:

Let's say I have a DistilBert architecture:

...
      (0): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
...

Note: the architecture summary is truncated for brevity.

Now, let's say I want to run my model using the library in a standard way, like this:

model.encode('whatever')

At the same time, while running that line, I want to save the outputs of the k_lin and q_lin layers shown in the architecture summary above.

The end goal is to store them in a list so that I can, for example, study the k_lin and q_lin outputs of each layer independently, without having to download the library's source code, modify it to expose k_lin and q_lin, and then maintain that fork on the side.
That doesn't seem like a good idea.

In PyTorch, you'd do it with hooks this way:

  1. Create a dict to keep the layer activations:

name_list = ['k_lin', 'q_lin']
NUM_LAYERS = 6
activations = {i: {name_: [] for name_ in name_list} for i in range(NUM_LAYERS)}

  2. Create a hook factory and register a forward hook on each q_lin and k_lin layer:

def create_hook(layer_, name_):
    def hook(model, input_, output_):
        activations[layer_][name_].append(output_.detach())
    return hook

# Access each layer's q_lin and k_lin and register the hooks
for i in range(NUM_LAYERS):
    model.transformer.layer[i].attention.q_lin.register_forward_hook(create_hook(i, 'q_lin'))
    model.transformer.layer[i].attention.k_lin.register_forward_hook(create_hook(i, 'k_lin'))

  3. Use the library normally; activations will now collect the output tensor of every forward pass through q_lin and k_lin:

model.encode('whatever')

Hope this helps
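
One way to approximate this in tch-rs today, since there is no hook registration: compose a layer with a closure that records its output. Below is a minimal sketch, assuming a single hypothetical q_lin linear layer built with tch's nn API; the shared Rc<RefCell<...>> stands in for the activations dict above, and exact signatures may differ slightly across tch versions:

use std::cell::RefCell;
use std::rc::Rc;
use tch::{nn, nn::Module, Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();

    // Shared storage standing in for PyTorch's `activations` dict.
    let acts: Rc<RefCell<Vec<Tensor>>> = Rc::new(RefCell::new(Vec::new()));
    let acts_in_hook = acts.clone();

    // Compose q_lin with a recording closure: a poor man's forward hook.
    let model = nn::seq()
        .add(nn::linear(&root / "q_lin", 768, 768, Default::default()))
        .add_fn(move |xs| {
            // Detach so the stored copy doesn't keep the autograd graph alive.
            acts_in_hook.borrow_mut().push(xs.detach());
            xs.shallow_clone()
        });

    let input = Tensor::rand(&[1, 768], (Kind::Float, Device::Cpu));
    let _output = model.forward(&input);
    println!("captured {} activation(s)", acts.borrow().len());
}

The same add_fn closure can be slotted in after each q_lin/k_lin wherever you build the forward pass yourself; what hooks add in PyTorch is the ability to do this without owning the forward code.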

@cpcdoy
Author

cpcdoy commented Jun 30, 2020

Just pinging the thread to see if there's any update on this

@NOBLES5E

@LaurentMazare It seems there is a register_hook function in the C++ API: https://github.com/pytorch/pytorch/blob/115494b00bf31549aa5227068bd66a3da9de469b/test/cpp/api/autograd.cpp#L547. Not sure whether it can be ported to tch-rs though.
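
Note that register_hook in that test is a backward hook: it fires on a tensor's gradient during backward(), not on a module's forward output. Even without hooks, the gradient such a callback would observe can be read after the fact in tch-rs by keeping a handle to the tensor. A minimal sketch (of_slice/sum signatures have shifted between tch releases, so treat this as illustrative):

use tch::{Kind, Tensor};

fn main() {
    // Keep a handle to the tensor whose gradient a backward hook would see.
    let x = Tensor::of_slice(&[1.0f32, 2.0, 3.0]).set_requires_grad(true);
    let y = (&x * &x).sum(Kind::Float);
    y.backward();
    // x.grad() now holds dy/dx = 2x, the value a register_hook callback
    // would have received during the backward pass.
    x.grad().print();
}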
