
Add max_context_length parameter to HookedTransformer #491

Closed

collingray wants to merge 3 commits into TransformerLensOrg:main from collingray:max_context_size

Conversation

@collingray
Contributor

@collingray collingray commented Jan 24, 2024

Description

Adds a max_context_length parameter to HookedTransformer.from_pretrained, HookedTransformer.from_pretrained_no_processing and loading_from_pretrained.get_pretrained_model_config, which caps the context length. By default, it is set to 2048 for from_pretrained, and None for the others (which leaves it uncapped).

Due to attention masks being stored ahead of time, memory usage grows with the square of the context length, leading to some models (e.g. Mistral) being unable to fit on consumer GPUs.
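To make the quadratic growth concrete, here is a rough back-of-the-envelope sketch (pure Python; sizes are illustrative, and the real layout depends on dtype and implementation details):

```python
# Rough illustration of why a per-layer n_ctx x n_ctx attention mask gets
# expensive. Assumes one byte per mask element (a boolean mask); the actual
# storage depends on the implementation.

def mask_bytes(n_ctx: int, n_layers: int, bytes_per_elem: int = 1) -> int:
    """Memory for one n_ctx x n_ctx mask per layer."""
    return n_ctx * n_ctx * bytes_per_elem * n_layers

# Mistral-style config: 32 layers, 32K trained context vs. a 2048 cap.
full = mask_bytes(32_768, 32)
capped = mask_bytes(2_048, 32)

print(f"32K context: {full / 2**30:.1f} GiB")  # 32.0 GiB
print(f"2K context:  {capped / 2**20:.1f} MiB")  # 128.0 MiB
```

Under these assumptions the masks alone need roughly 32 GiB at the trained 32K context, versus about 128 MiB at a 2K cap, which is the gap that motivates this parameter.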

This could potentially be a breaking change if the longer context was relied on.

Fixes #490

Related to #479

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

device: Optional[str] = None,
n_devices: int = 1,
default_prepend_bos: bool = True,
max_context_length: Optional[int] = None,
Contributor Author

@collingray collingray Jan 24, 2024

I'm not positive that this is the correct default to use here.

My thinking is that "get_pretrained_model_config" implies that it will return the config mostly unaltered

Collaborator

I would call it override_default_context_length maybe? That has the problem of sounding like a Boolean, but combined with the type being Optional[int] I think it's pretty obvious what's going on

Contributor Author

My reasoning for max vs. override was that max makes it clear that the context length will only be changed if it's too big, whereas there is some ambiguity about whether override would always set it to a fixed value.

Contributor Author

perhaps override_max_context_length would actually be best

Collaborator

override_max_context_length sounds good to me

@collingray collingray marked this pull request as ready for review January 24, 2024 00:50
other HuggingFace functions when compatible. For some models or arguments it doesn't
work, especially for models that are not internally loaded with HuggingFace's
from_pretrained (e.g. SoLU models).
max_context_length: The maximum context length to use for the model. Defaults to 2048. Can be set to
Collaborator

This is a complex parameter, so I think the docstring should be more detailed. I would rephrase the comment to something like:

This allows us to override the TransformerLens default max context length for the model. TransformerLens may default to a smaller max context length than a model was trained with; you can use this parameter to restore a higher max if needed for your use case. Note that this will only work with positional encodings like rotary and sinusoidal that support multiple max context lengths; it will not work with absolute positional encodings. If you set the max context to be larger than what the model was trained with, it is unlikely that it can make full use of the context.
For example, Mistral 7B was trained with a 32K context length, but TransformerLens defaults to 2K, because each attention layer has an n_ctx x n_ctx attention mask attached, which can get very memory intensive. If you need the full context, you can set this to a higher value.

"d_mlp": hf_config.intermediate_size // 2,
"n_layers": hf_config.num_hidden_layers,
"n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
"n_ctx": hf_config.seq_length,
Collaborator

Why did you change this?

Collaborator

And why didn't you change the Mistral entry?

Contributor Author

My thinking was to separate the parsing of the configs from the filtering of them. It's functionally the same; the filter is just applied at a different stage.

Collaborator

Hmm. I lean towards doing it here? My reasoning is that, if someone wants to understand the Qwen config, they can just come here and read this block of code. If there's later code that adjusts it, that's hard to notice, and may be actively misleading, because they'll see 32K here and expect that to end up in the model.

A counter-argument is that adding an "if n_ctx > 4K and override_max_ctx is None: n_ctx=4K" statement can be done once and work for all present and future models, rather than needing people to notice each time. But on net I still prefer all the "understanding Qwen config" code to be in one place
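The one-line post-parse cap described in the counter-argument could look something like this hypothetical sketch (not the actual TransformerLens code; the 4K threshold, `DEFAULT_MAX_N_CTX`, and `default_cap_n_ctx` are illustrative names):

```python
# Hypothetical sketch of a generic post-parse cap: applied once after config
# parsing, it covers all present and future models, at the cost of the
# per-model config block no longer showing the final n_ctx.

DEFAULT_MAX_N_CTX = 4_096  # illustrative default cap, not a real constant

def default_cap_n_ctx(cfg_dict: dict, override_max_ctx=None) -> dict:
    # Cap only when the user has not explicitly overridden the context length.
    if override_max_ctx is None and cfg_dict.get("n_ctx", 0) > DEFAULT_MAX_N_CTX:
        cfg_dict["n_ctx"] = DEFAULT_MAX_N_CTX
    return cfg_dict

print(default_cap_n_ctx({"n_ctx": 32_768})["n_ctx"])                            # 4096
print(default_cap_n_ctx({"n_ctx": 32_768}, override_max_ctx=32_768)["n_ctx"])   # 32768
```

This is exactly the trade-off being debated: the cap works everywhere with no per-model edits, but a reader of the Qwen config block would see 32K and not realize the final model uses less.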

Contributor Author

That's a fair point. Another option would be to add a comment mentioning that it may be overridden by get_pretrained_model_config.

I'm fine with either, lmk which you prefer

Contributor Author

@collingray collingray Jan 24, 2024

Oh, I forgot the other reason I went with that approach: someone using the package not from source would have to directly modify the config in order to change n_ctx.

Though this doesn't really matter if you can also extend n_ctx with the parameter


# If max_context_length is specified, use it to cap n_ctx
if max_context_length and "n_ctx" in cfg_dict:
cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], max_context_length)
Collaborator

IMO if the user sets max_context_length higher than n_ctx, we should allow it to increase n_ctx

Contributor Author

@collingray collingray Jan 24, 2024

Would this be possible, assuming that n_ctx is already the maximum context length of the model?

edit: oh I see what you're saying, yeah I can add this

Collaborator

Yeah, rotary models allow you to increase n_ctx beyond what a model was trained with, it may just go completely off the rails. Generally what happens, I think, is that it works fine, but doesn't get any better at predicting the 64Kth token than the 32Kth token (but is better at predicting 32Kth than 16Kth)

Contributor Author

@collingray collingray Jan 24, 2024

It may actually make sense to split this into multiple parameters (perhaps override_max_context_length and extend_context_length) in order to still apply a default maximum n_ctx, since if the single parameter always sets n_ctx directly to the provided value, smaller models will be raised up to that n_ctx without this being clear to the user

Collaborator

I don't think it's worth having two parameters for such a subtle distinction. IMO if you use this parameter you should know what you're doing. Maybe have it output a warning?

Contributor Author

I had set the default for it at 2048 in HookedTransformer.from_pretrained, the thinking being that people using larger contexts likely know the library better, but new users may get tripped up trying to use Mistral and having to change a flag to get it to load

A warning for when n_ctx is being extended is a good idea though
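The behavior being converged on here, a single override that can also extend n_ctx but warns when doing so, might look roughly like this hypothetical sketch (the name override_max_context_length follows the naming discussion above and is not the final API):

```python
import warnings

# Hypothetical sketch: a single Optional[int] override that sets n_ctx
# directly, warning when it extends past the config's trained length.
# Names and behavior are illustrative, not the merged TransformerLens code.

def apply_context_override(cfg_dict: dict, override_max_context_length=None) -> dict:
    if override_max_context_length is None or "n_ctx" not in cfg_dict:
        return cfg_dict
    if override_max_context_length > cfg_dict["n_ctx"]:
        warnings.warn(
            "Extending n_ctx beyond the model's trained context length; "
            "rotary/sinusoidal models will run, but quality may degrade "
            "past the trained length."
        )
    cfg_dict["n_ctx"] = override_max_context_length
    return cfg_dict

print(apply_context_override({"n_ctx": 2_048}, 4_096)["n_ctx"])  # 4096, with a warning
print(apply_context_override({"n_ctx": 2_048})["n_ctx"])         # 2048, untouched
```

The warning addresses the concern above: smaller models silently raised to a large n_ctx would otherwise surprise the user.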

Collaborator

@neelnanda-io neelnanda-io left a comment

Overall looks good, I added some minor comments

@neelnanda-io
Collaborator

BTW, thanks a lot for all the PRs recently!

@collingray
Contributor Author

BTW, thanks a lot for all the PRs recently!

I'm glad to help, it's a very useful library

Also lmk if there are any other particular issues that would be good to solve

@andyrdt
Contributor

andyrdt commented Jan 24, 2024

Will this notion of max_context_length be important after we fix the memory-inefficient implementation of attention masking?

I personally feel like we should just fix the attention mask implementation, and save adding another parameter (that would become obsolete).

@collingray
Contributor Author

Will this notion of max_context_length be important after we fix the memory-inefficient implementation of attention masking?

I personally feel like we should just fix the attention mask implementation, and save adding another parameter (that would become obsolete).

No, I don't think it would be needed after (although the option to extend the context length could be); my original intention was for it to be a quick fix for Mistral's memory usage.

Perhaps it would be a better idea to just directly lower it in the config and then work on the attention mask memory issue.

@collingray
Contributor Author

Closing this in favor of #493. It's perhaps worth adding the option to extend the context length, but there seems to be no need to shorten it given how little memory it should use



Development

Successfully merging this pull request may close these issues.

[Proposal] Change Mistral's config to reduce context size from 32k to 4k

3 participants