Add max_context_length parameter to HookedTransformer #491

collingray wants to merge 3 commits into TransformerLensOrg:main

Conversation
```python
    device: Optional[str] = None,
    n_devices: int = 1,
    default_prepend_bos: bool = True,
    max_context_length: Optional[int] = None,
```
I'm not positive that this is the correct default to use here.
My thinking is that get_pretrained_model_config implies that it will return the config mostly unaltered.
I would call it override_default_context_length maybe? That has the problem of sounding like a Boolean, but combined with the type being Optional[int] I think it's pretty obvious what's going on
My reasoning for max vs. override was that max makes it clear that the context length will only be changed if it's too big, whereas there is some ambiguity about whether override would always set it to a fixed value.
perhaps override_max_context_length would actually be best
override_max_context_length sounds good to me
```python
        other HuggingFace functions when compatible. For some models or arguments it doesn't
        work, especially for models that are not internally loaded with HuggingFace's
        from_pretrained (e.g. SoLU models).
    max_context_length: The maximum context length to use for the model. Defaults to 2048. Can be set to
```
This is a complex parameter, so I think the docstring should be more detailed. I would rephrase the comment to something like:
This allows us to override the TransformerLens default max context length for the model. TransformerLens may default to a smaller max context length than a model was trained with; you can use this parameter to restore a higher max if needed for your use case. Note that this will only work with positional encodings like rotary and sinusoidal that support multiple max context lengths; it will not work with absolute position encodings. If you set the max context to be larger than what the model was trained with, it is unlikely that it can make full use of the context.
For example, Mistral 7B was trained with a 32K context length, but TransformerLens defaults to 2K, because each attention layer has a n_ctx x n_ctx attention mask attached, which can get very memory intensive. If you need the full context, you can set this to a higher value.
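As a back-of-envelope illustration of the memory cost described above (the per-layer float32 mask is an assumption for the sake of arithmetic, not a claim about TransformerLens internals):

```python
# Rough memory for an n_ctx x n_ctx attention mask stored per attention
# layer, assuming float32 entries (4 bytes each). Illustrative only.
def mask_memory_gib(n_ctx: int, n_layers: int, bytes_per_entry: int = 4) -> float:
    return n_ctx * n_ctx * n_layers * bytes_per_entry / 1024**3

# Mistral 7B has 32 layers; compare the 2K cap with the full 32K context:
print(mask_memory_gib(2048, 32))   # 0.5
print(mask_memory_gib(32768, 32))  # 128.0
```

The quadratic growth in n_ctx is why the default cap exists: going from 2K to 32K multiplies the mask memory by 256.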
| "d_mlp": hf_config.intermediate_size // 2, | ||
| "n_layers": hf_config.num_hidden_layers, | ||
| "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big | ||
| "n_ctx": hf_config.seq_length, |
Why did you change this?
And why didn't you change the Mistral entry?
My thinking was to separate the parsing of the configs from the filtering of them. It's functionally the same; the filter is just applied at a different stage.
Hmm. I lean towards doing it here? My reasoning is that, if someone wants to understand the Qwen config, they can just come here and read this block of code. If there's later code that adjusts it, that's hard to notice, and may be actively misleading, because they'll see 32K here and expect that to end up in the model.
A counter-argument is that adding an "if n_ctx > 4K and override_max_ctx is None: n_ctx=4K" statement can be done once and work for all present and future models, rather than needing people to notice each time. But on net I still prefer all the "understanding Qwen config" code to be in one place
That's a fair point; another option would be to add a comment mentioning that it may be overridden by get_pretrained_model_config.
I'm fine with either, lmk which you prefer
Oh, I forgot the other reason I went with that approach, which is that someone using the package not from source would have to directly modify the config in order to change n_ctx.
Though this doesn't really matter if you can also extend n_ctx with the parameter.
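The centralized filter described in the counter-argument above could be sketched roughly as follows (the 4K threshold and the names `cap_n_ctx` and `override_max_ctx` are illustrative, taken from the comment, not actual TransformerLens code):

```python
DEFAULT_MAX_N_CTX = 4096  # illustrative threshold from the discussion above

def cap_n_ctx(cfg_dict: dict, override_max_ctx=None) -> dict:
    # Applied once after config parsing, so it covers all present and
    # future models, instead of hard-coding a capped n_ctx in each
    # model's config-parsing block.
    if override_max_ctx is None and cfg_dict.get("n_ctx", 0) > DEFAULT_MAX_N_CTX:
        cfg_dict["n_ctx"] = DEFAULT_MAX_N_CTX
    return cfg_dict

print(cap_n_ctx({"n_ctx": 32768}))         # {'n_ctx': 4096}
print(cap_n_ctx({"n_ctx": 32768}, 32768))  # {'n_ctx': 32768}
```

The trade-off is exactly the one debated here: one generic rule versus keeping each model's effective config readable in one place.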
```python
# If max_context_length is specified, use it to cap n_ctx
if max_context_length and "n_ctx" in cfg_dict:
    cfg_dict["n_ctx"] = min(cfg_dict["n_ctx"], max_context_length)
```
IMO if the user sets max_context_length higher than n_ctx, we should allow it to increase n_ctx
Would this be possible, assuming that n_ctx is already the maximum context length of the model?
edit: oh I see what you're saying, yeah I can add this
Yeah, rotary models allow you to increase n_ctx beyond what a model was trained with, it may just go completely off the rails. Generally what happens, I think, is that it works fine, but doesn't get any better at predicting the 64Kth token than the 32Kth token (but is better at predicting 32Kth than 16Kth)
It may actually make sense to split it into multiple parameters for this (perhaps override_max_context_length and extend_context_length), in order to still use a default maximum n_ctx; if it always sets n_ctx directly to the provided value, then smaller models will get raised up to that n_ctx without this being clear to the user.
I don't think it's worth having two parameters for such a subtle distinction. IMO if you use this parameter you should know what you're doing. Maybe have it output a warning?
I had set the default for it at 2048 in HookedTransformer.from_pretrained, the thinking being that people using larger contexts likely know the library better, but new users may get tripped up trying to use Mistral and having to change a flag to get it to load.
A warning for when n_ctx is being extended is a good idea, though.
neelnanda-io left a comment:
Overall looks good, I added some minor comments
BTW, thanks a lot for all the PRs recently!
I'm glad to help, it's a very useful library. Also lmk if there are any other particular issues that would be good to solve.
Will this notion of max_context_length still be needed once the attention mask implementation is fixed? I personally feel like we should just fix the attention mask implementation, and save adding another parameter (that would become obsolete).
No, I don't think it would be needed after that (although the option to extend the context length could be); my original intention was for it to be a quick fix for Mistral's memory usage. Perhaps it would be a better idea to just directly lower it in the config and then work on the attention mask memory issue.
Closing this in favor of #493. It's perhaps worth adding the option to extend the context length, but there seems to be no need to shorten it given how little memory it should use.
Description
Adds a `max_context_length` parameter to `HookedTransformer.from_pretrained`, `HookedTransformer.from_pretrained_no_processing` and `loading_from_pretrained.get_pretrained_model_config`, which caps the context length. By default, it is set to `2048` for `from_pretrained`, and `None` for the others (which leaves it uncapped).

Due to attention masks being stored ahead of time, memory usage grows with the square of context length, leading to some models (e.g. Mistral) being unable to fit on consumer GPUs.
This could potentially be a breaking change if the longer context was relied on.
Fixes #490
Related to #479