Soft prompts #231

Merged: 47 commits merged into main from the PrefixTuning branch on Dec 1, 2022

Conversation

dirkgr (Member) commented Mar 15, 2022

In code:

import transformers
t = transformers.AutoModel.from_pretrained("gpt2")
twp = make_prefix_transformer(t, prefix_length=3)

In config files:

{
    model: {
        type: "transformers::with_soft_prompt",
        prompt_length: 3,
        model: {
            type: "transformers::AutoModelForCausalLM::from_pretrained",
            pretrained_model_name_or_path: "gpt2"
        },
    }
}

Missing:

  • Tests
  • Docs
  • Try it with T5
  • A proper end-to-end training config that uses this
  • Add an easy way to make only the prefix trainable and leave the rest of the weights alone (a rough sketch of this follows the list)
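
For the last item on that list, a minimal sketch of what "only the prefix trainable" could look like. This is not the PR's implementation; the "prompt_embedding" marker string is an assumption about how the soft-prompt parameters might be named.

import torch

def freeze_all_but_soft_prompt(model: torch.nn.Module, marker: str = "prompt_embedding") -> None:
    # Freeze every parameter except those whose name contains `marker`.
    # `marker` is a hypothetical substring; the real soft-prompt module in this
    # PR may be named differently.
    for name, param in model.named_parameters():
        param.requires_grad = marker in name

An optimizer built from filter(lambda p: p.requires_grad, model.parameters()) would then update only the soft prompt.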

return model


Model.register("transformers::with_prefix")(make_prefix_transformer)

Reviewer:

As this will probably be difficult to change later, it's worth thinking about the terminology. Prefix tuning? Prompt tuning? Something else?

dirkgr (Member Author):

Maybe we'll call it with_soft_prompt?

Reviewer:

Idk. At least some people use the prompt vs. prefix tuning distinction to refer to the shallow (input layer only) vs. deep distinction. I have no strong preference, but it's worth thinking about carefully and maybe asking for wider opinions.

(Several review comments on tango/integrations/transformers/prefix_transformer.py were marked outdated and resolved.)
# Because PyTorch hooks don't support kwargs, we monkey patch the forward method 🙈
old_forward = model.forward

def new_forward(*args, **kwargs):

Reviewer:

What made me turn away from monkeypatching in my own code is that the patched function doesn't need/have a self, so there might be some fundamental differences between the old vs. new forward. If I were two years younger I probably would have voted for monkeypatching, but the older me is less adventurous and worries more about safety. Go ahead if you're confident that this is safe, but at the very least I would suggest some sort of assertion that checks forward has not already been monkeypatched (because if it had, the logic would be incorrect).

dirkgr (Member Author):

I am a little uneasy about it, but I think it beats the alternatives. At the very least I want to see where it goes and where it falls down, if it does. Also, apparently there is movement on the PyTorch side to allow kwargs in hooks. When that comes true, we can do this properly.

As for your specific concern, this will work fine even if forward() has already been monkey patched before.

Reviewer:

Will it? Wouldn't the patching happen multiple times, at each recursion level?

dirkgr (Member Author):

If the thing you pass into this function was monkey patched before, then old_forward ends up being the first level of monkey patching, and it will get called when we go one level down.

old_forward becomes part of the closure of new_forward. That's how the chain of forward methods is maintained.
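
To illustrate the closure chain with a toy module (a sketch, not the actual prefix_transformer.py code):

import torch

class Toy(torch.nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

def patch_forward(module: torch.nn.Module, tag: str) -> None:
    # Capture the *current* forward in the closure; patches applied later wrap
    # the ones applied earlier, which in turn wrap the original bound method.
    old_forward = module.forward

    def new_forward(*args, **kwargs):
        print(f"patch {tag} running")
        return old_forward(*args, **kwargs)

    module.forward = new_forward

m = Toy()
patch_forward(m, "level 1")
patch_forward(m, "level 2")
m(torch.ones(2))  # prints "patch level 2 running", then "patch level 1 running"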

Reviewer:

Right. So it would be like this, no?

forward:  # monkey patch lvl 1
  patch_tensor
  forward:  # monkey patch lvl 2  
    patch_tensor
    forward  # original

dirkgr (Member Author):

Yes

Reviewer:

So the tensors will be patched twice

dirkgr (Member Author):

The inputs and outputs should be patched twice. That is correct.

What won't work is calling set_input_embeddings() twice the way I have it here, because _WithPromptEmbedding reaches into the original embedding's internals.

Reviewer:

Oh, what I was worried about is the method being unintentionally called twice. I can't think of a case where it is intentionally called twice.

dirkgr (Member Author):

c8d1b86 should make it possible to stack two prompt-enabled transformers on top of each other.

I think it's important that we make sure this pattern works for other modifications as well. What if we implement adapters the same way, and we want to run both at the same time? The whole point of going for this "looks like a normal Hugging Face transformer" approach is that it should be easy to combine with other components that do the same thing.
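
For reference, a hedged sketch of a prompt-prepending embedding wrapper in the spirit of this discussion. It is illustrative only; the PR's _WithPromptEmbedding works differently (it reaches into the original embedding's internals, which is what prevents calling set_input_embeddings() twice).

import torch

class PromptedEmbedding(torch.nn.Module):
    # Illustrative wrapper, not the PR's _WithPromptEmbedding: it keeps the
    # original embedding intact and owns the trainable soft-prompt vectors.
    def __init__(self, original: torch.nn.Embedding, prompt_length: int):
        super().__init__()
        self.original = original
        self.soft_prompt = torch.nn.Parameter(
            torch.randn(prompt_length, original.embedding_dim) * 0.02
        )

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedded = self.original(input_ids)  # (batch, seq, dim)
        prompt = self.soft_prompt.unsqueeze(0).expand(input_ids.shape[0], -1, -1)
        return torch.cat([prompt, embedded], dim=1)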


result = old_forward(*args, **kwargs)

if isinstance(result, CausalLMOutputWithCrossAttentions):

Reviewer:

Some comment for what this is doing?

dirkgr (Member Author):

Yeah, I have to go through and write docs and whatnot.

dirkgr (Member Author) commented Mar 16, 2022

Oh no, I found a big problem with this. It doesn't work with past_key_values. Fix incoming.

dirkgr (Member Author) commented Mar 16, 2022

This does not work for T5 at all 😭. I'm no longer sure this approach of patching the model will work. The huggingface generation code makes calls into the middle of their model, instead of always going through the forward() method. So patching forward() doesn't work. And patching the forward() of an internal module breaks all sorts of assumptions that other parts of the code have about that forward() method.

ZhaofengWu:

> This does not work for T5 at all 😭. I'm no longer sure this approach of patching the model will work. The huggingface generation code makes calls into the middle of their model, instead of always going through the forward() method. So patching forward() doesn't work.

Is this problematic for generation only?

> And patching the forward() of an internal module breaks all sorts of assumptions that other parts of the code have about that forward() method.

This is what I was worrying about above.

dirkgr (Member Author) commented Mar 16, 2022

Copying from Slack:

I can patch just the encoder for T5. Then the soft prompt has the opportunity to change how the rest of the prompt is encoded. But the encoded soft tokens are not part of the encoder output, and cannot be attended to by the decoder. @ZhaofengWu, is that important?

dirkgr changed the title from "Prefix tuning" to "Soft prompts" on Mar 17, 2022
dirkgr (Member Author) commented Mar 17, 2022

Just to resolve this chain of comments: I made it work with T5.

dirkgr requested a review from AkshitaB on March 18, 2022
CHANGELOG.md:
@@ -262,6 +262,7 @@ instead of `ModuleNotFound`.
- Added the "-n/--name" option to `tango run`. This option allows the user to give the run an arbitrary name.
- Added a convenience property `.workspace` to `Step` class that can be called from a step's `.run()` method to get the current `Workspace` being used.
- Gave `FromParams` objects (which includes all `Registrable` objects) the ability to version themselves.
- Added the `transformers::with_soft_prompt` integration, to make soft-prompted prefix transformers easy.
Contributor:

Should move this up in the changelog.


)
)
r = random.Random(random_seed)
indices = torch.tensor(r.sample(range(5000), prompt_length))
Contributor:

Where does 5000 come from?

dirkgr (Member Author):

It's a number that Zhaofeng used in his code. He got it from some paper.

dirkgr (Member Author):

That sure is a little weird. Maybe it should sample from the entire original embedding.

Reviewer:

It was originally used in https://arxiv.org/abs/2104.08691 and subsequently in other papers such as https://arxiv.org/abs/2108.04106, and of course ours. The idea is to only use the representations of the top-5000 tokens.

Reviewer:

I'd keep 5000 or at least have some flag to control this.

dirkgr (Member Author):

Is the idea that the top 5000 most frequent tokens have received more training data and are therefore better?

Reviewer:

Yeah, that's my understanding

dirkgr (Member Author):

I made it configurable, with a default of 5000.
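
A hedged sketch of that initialization scheme, with the vocabulary cutoff exposed as a parameter; the function and argument names here are illustrative, not necessarily what the PR uses.

import random

import torch
import transformers

def sample_prompt_init(model, prompt_length: int, random_seed: int,
                       num_candidate_tokens: int = 5000) -> torch.Tensor:
    # Sample `prompt_length` distinct token ids from the first
    # `num_candidate_tokens` rows of the input embedding, and use those
    # embedding vectors to initialize the soft prompt.
    r = random.Random(random_seed)
    indices = torch.tensor(r.sample(range(num_candidate_tokens), prompt_length))
    embedding = model.get_input_embeddings()
    return embedding.weight[indices].detach().clone()

model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
prompt_init = sample_prompt_init(model, prompt_length=3, random_seed=42)
print(prompt_init.shape)  # torch.Size([3, 768]) for gpt2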

dirkgr marked this pull request as ready for review on November 29, 2022
dirkgr (Member Author) commented Nov 30, 2022

This is ready for another review.

patch_tensor(kwargs, "labels")
patch_tensor(kwargs, "attention_mask", 1)
patch_tensor(kwargs, "token_type_ids")
patch_tensor_with_indices(kwargs, "position_ids", prompt_length)
Contributor:

So, if the position ids are originally [0, 1, 2, 3, 4], they will now be [0, 1, 2, ..., prompt_len-1, 0, 1, 2, 3, 4]?

dirkgr (Member Author):

Yes, that's right. I could see it going the other way, but I think it's important that the output does not change in the case where the soft prompt is configured to do nothing. Also, if we offset the position ids, we would decrease the max length that the model can handle, which is uncomfortable.
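
A sketch of what the patching above amounts to for the attention mask and position ids; these are standalone illustrations, not the PR's actual patch_tensor / patch_tensor_with_indices implementations.

import torch

def prepend_ones_to_mask(attention_mask: torch.Tensor, prompt_length: int) -> torch.Tensor:
    # Soft-prompt positions are always attended to, so pad the mask with ones on the left.
    ones = torch.ones(attention_mask.shape[0], prompt_length,
                      dtype=attention_mask.dtype, device=attention_mask.device)
    return torch.cat([ones, attention_mask], dim=1)

def prepend_position_ids(position_ids: torch.Tensor, prompt_length: int) -> torch.Tensor:
    # The prompt gets its own 0..prompt_length-1 range and the original ids stay
    # unchanged, so the wrapped model behaves identically with an empty prompt
    # and the maximum sequence length is not reduced.
    prompt_positions = torch.arange(prompt_length, device=position_ids.device)
    prompt_positions = prompt_positions.unsqueeze(0).expand(position_ids.shape[0], -1)
    return torch.cat([prompt_positions, position_ids], dim=1)

print(prepend_position_ids(torch.tensor([[0, 1, 2, 3, 4]]), prompt_length=3))
# tensor([[0, 1, 2, 0, 1, 2, 3, 4]])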

dirkgr merged commit 73bfa86 into main on Dec 1, 2022
dirkgr deleted the PrefixTuning branch on December 1, 2022