Add decoding time compression #138
Signed-off-by: Maximilian Jeblick <maximilianjeblick@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
/ok to test a43fc19
Is this normal @maxjeblick?
logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")

# Greedy decoding for each question
answers = []
Here, we don't exit the context manager after prefilling. This may break kvzip press.
A straightforward solution might be to have two context managers:
should_perform_prefill_compression = press is not None and not isinstance(press, (DecodingPress, PrefillDecodingPress))
with press(self.model) if should_perform_prefill_compression else contextlib.nullcontext():
(and a subsequent with block for answer generation). This way, we fully ensure the same control flow when no decoding press is used.
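To make the suggestion concrete, here is a minimal sketch of the two-context-manager control flow. `Press`, `DecodingPress`, `run_pipeline`, and the event log are hypothetical stand-ins written for this illustration, not the kvpress classes or the pipeline itself; the only point is that the first `with` block exits right after prefilling.

```python
import contextlib


class Press:
    """Hypothetical stand-in for a press: calling it yields a context manager."""

    def __init__(self, name, events):
        self.name = name
        self.events = events

    @contextlib.contextmanager
    def __call__(self, model):
        self.events.append(f"enter {self.name}")  # hooks would be registered here
        try:
            yield
        finally:
            self.events.append(f"exit {self.name}")  # hooks would be removed here


class DecodingPress(Press):
    """Hypothetical stand-in for a press that compresses only while decoding."""


def run_pipeline(model, press, events):
    """Record the order of prefill, generation, and context manager enter/exit."""
    prefill = press is not None and not isinstance(press, DecodingPress)
    decoding = press is not None and isinstance(press, DecodingPress)

    # First block: the press is active during prefilling and exits right after it.
    with press(model) if prefill else contextlib.nullcontext():
        events.append("prefill")

    # Second block: the press is active during generation for decoding presses only.
    with press(model) if decoding else contextlib.nullcontext():
        events.append("generate")
    return events
```

With a non-decoding press, the recorded order is enter, prefill, exit, generate, which is the behavior a press relying on an exit after prefilling would need.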
As you pointed out, @maxjeblick, KVZip is not supported because it is not a scorer press.
OK, sorry @maxjeblick, I missed the problem. I tried the two context managers approach, wdyt?
maxjeblick left a comment
Thanks a lot for working on decoding press and adapting the code!
Some comments:
- IMO, decoding press can be refactored a bit. In particular, cache handling (quantized/non-quantized) now appears in various presses and can be factored out.
- The notebook probably needs to be rerun to produce a correct output (the max generation length is too low). The generated text output could also be formatted nicely for display.
- The PR changes the pipeline logic: the `with press` context manager now exits AFTER generation, not before. This will most likely cause kvzip press to stop working (it relies on the context manager exiting after prefilling).
    Target number of tokens to keep after compression.
hidden_states_buffer_size : int, default=128
    Maximum number of hidden states to keep before compression. Larger values use more GPU memory.
    Note: Some presses don't need buffered hidden states and can set this to 0 to use only the
)

cache_layer = cache.layers[module.layer_idx]
if isinstance(cache, QuantizedCache):
This is a candidate for refactoring, as it also appears in the base presses.
Moved to utils.
return output
# print(f"Adding hidden states to buffer: {hidden_states.shape}")
# Add current hidden states to buffer for this layer
self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone())
The hidden states buffer might grow longer than hidden_states_buffer_size.
IMO, the code makes sense; we may need to adapt the docstring to be more explicit.
We should already use torch.cat here, so that self.hidden_states_buffer[layer_idx] is always a tensor; that makes it easier to use.
We say that in the docstring: "Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim])"
> We should already use torch.cat
Maybe it is cleaner like this? Handling the first torch.cat when the buffer is empty would require extra complexity and make the code less readable.
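The trade-off discussed here can be sketched without PyTorch. In this illustration, plain Python lists stand in for tensors and list concatenation stands in for `torch.cat` along the sequence dimension; `ChunkListBuffer` and `EagerBuffer` are hypothetical names, not classes from the PR.

```python
class ChunkListBuffer:
    """Append per-step chunks to a list and concatenate only when needed."""

    def __init__(self):
        self.chunks = []

    def append(self, chunk):
        self.chunks.append(list(chunk))

    def materialize(self):
        # Single concatenation at compression time (stand-in for one torch.cat call).
        return [x for chunk in self.chunks for x in chunk]


class EagerBuffer:
    """Keep one concatenated buffer at all times."""

    def __init__(self):
        self.data = None  # no buffer yet: this is the extra case to handle

    def append(self, chunk):
        if self.data is None:
            # First step: there is nothing to concatenate with.
            self.data = list(chunk)
        else:
            # Later steps: concatenate along the sequence axis.
            self.data = self.data + list(chunk)

    def materialize(self):
        return [] if self.data is None else self.data
```

Both variants hold the same content. The eager one gives callers a single object at every step, at the price of the empty-buffer branch mentioned in the reply above.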
logger.debug(f"Applied decoding compression: " f"keys.shape: {keys.shape}, values.shape: {values.shape}")

# Update cache with compressed keys and values
if isinstance(cache, QuantizedCache):
Again: Could become a dedicated util function.
[project]
name = "kvpress"
- version = "0.3.0"
+ version = "0.3.1"
I wouldn't update the version.
Decoding press will become version 1.0.0. With this PR, we are adding the functionality, but we won't cut a version yet so that we can test the feature more thoroughly.
@@ -0,0 +1,111 @@
# Generation Presses (Experimental)
We can move this to the main readme in a new dropdown section.
/ok to test a0053c9
@@ -132,13 +134,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic

cache_layer = cache.layers[module.layer_idx]
if isinstance(cache, QuantizedCache):
We could also extract this into keys, values = extract_key_values(cache_layer).
WDYT?
Do you mean just
def extract_key_values(layer): return layer.keys, layer.values
Maybe a bit overkill? In the end, we are just accessing two fields.
> Maybe a bit overkill in the end we are just accessing two fields ?
The same if-else block is present in several parts of the code.
To me, it thus makes sense to extract the whole block.
I see, makes sense!
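As a rough illustration of extracting the whole if-else block, the sketch below puts the quantized/non-quantized branching into one helper. `CacheLayer`, `QuantizedCacheLayer`, their fields, and the toy scale-based dequantization are assumptions made up for this example; they are not the transformers or kvpress cache classes, and the real helper's signature may differ.

```python
from dataclasses import dataclass


@dataclass
class CacheLayer:
    """Hypothetical plain cache layer: keys and values are stored as-is."""

    keys: list
    values: list


@dataclass
class QuantizedCacheLayer:
    """Hypothetical quantized cache layer: stored integers plus one scale factor."""

    quantized_keys: list
    quantized_values: list
    scale: float

    def dequantize(self, data):
        # Toy dequantization, only meant to make the two branches differ.
        return [x * self.scale for x in data]


def extract_keys_and_values(cache_layer):
    """Single place for the branching that was repeated across presses."""
    if isinstance(cache_layer, QuantizedCacheLayer):
        keys = cache_layer.dequantize(cache_layer.quantized_keys)
        values = cache_layer.dequantize(cache_layer.quantized_values)
        return keys, values
    return cache_layer.keys, cache_layer.values
```

Call sites then shrink to `keys, values = extract_keys_and_values(cache_layer)`, and any change to the quantized path happens in one function.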
/ok to test a2982f9
with press(self.model) if press is not None else contextlib.nullcontext():
# We only perform prefill compression if the press is not a decoding or prefill decoding press
perform_prefill_compression = press is not None and not isinstance(press, (DecodingPress, PrefillDecodingPress))
PrefillDecodingPress needs to be excluded.
Right, my bad 🫠
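One reading of this fix, assuming "excluded" means dropping `PrefillDecodingPress` from the `isinstance` tuple of the prefill check, is sketched below. The three classes are empty hypothetical stand-ins and `compression_phases` is a helper invented for this illustration; the real class hierarchy may differ.

```python
class BasePress:
    """Hypothetical stand-in for a prefill-only press."""


class DecodingPress(BasePress):
    """Hypothetical stand-in for a press that compresses only while decoding."""


class PrefillDecodingPress(BasePress):
    """Hypothetical stand-in for a press that compresses in both phases."""


def compression_phases(press):
    """Return (perform_prefill_compression, perform_decoding_compression)."""
    # Assumption: only pure decoding presses skip prefill compression.
    prefill = press is not None and not isinstance(press, DecodingPress)
    decoding = press is not None and isinstance(press, (DecodingPress, PrefillDecodingPress))
    return prefill, decoding
```

Under this reading, a `PrefillDecodingPress` is active in both phases, a `DecodingPress` only while decoding, and any other press only during prefilling.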
@maxjeblick could you also check whether the README is clear enough?
/ok to test 31cf83c
Moved PR to #139 due to DCO issues.
PR description
(Not ready to merge)
This PR introduces decoding time compression (#55) and includes significant contributions from @maxjeblick (Thanks, Max! 🙏).
The main additions are two presses, Decoding and PrefillDecoding, that perform decoding time compression. Apart from standard code review, some things to discuss:
Checklist
- `make test`
- `make style` (on errors, try to fix with `make format`)
- `git commit -s`
- `mypress_press.py` is in the `presses` directory
- `MyPress` is in `__init__.py`
- `README.md` is updated with a 1 liner about the new press in the Available presses section
- `default_presses` list in `tests/default_presses.py`