
Add decoding time compression#138

Closed
alessiodevoto wants to merge 181 commits into main from aledev/decoding_press

Conversation

@alessiodevoto
Collaborator

PR description

(Not ready to merge)
This PR introduces decoding-time compression (#55) and includes significant contributions from @maxjeblick (Thanks Max! 🙏).

The main additions are two presses, DecodingPress and PrefillDecodingPress, which perform decoding-time compression. Apart from the standard code review, some things to discuss:

  • Where to put the documentation for decoding-time compression in the README.md (right now I left the original comments in the generation.md file, but they have to be moved). We need to be extra careful to make sure it is not confusing.
  • Right now the evaluation code needs to be refactored to support decoding-time compression evaluation. We will need to add benchmarks and change the eval loop slightly. We can address this in a future PR.
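To make the feature concrete, here is a toy, framework-free sketch of what decoding-time compression does: during generation, new entries keep arriving in the cache, and the cache is pruned back to a fixed budget on the fly. All names and the scoring scheme below are made up for illustration; the real presses operate on attention key/value tensors, not score lists.

```python
# Illustrative sketch only: a toy model of decoding-time compression.
# compress / decode_with_compression / max_capacity are hypothetical names,
# not the kvpress API.

def compress(cache: list[float], max_capacity: int) -> list[float]:
    """Keep the max_capacity entries with the highest scores (order preserved)."""
    if len(cache) <= max_capacity:
        return cache
    threshold = sorted(cache, reverse=True)[max_capacity - 1]
    kept, out = 0, []
    for score in cache:
        if score >= threshold and kept < max_capacity:
            out.append(score)
            kept += 1
    return out

def decode_with_compression(scores: list[float], max_capacity: int) -> list[float]:
    """Append one 'KV entry' per generated token, compressing whenever the cache overflows."""
    cache: list[float] = []
    for s in scores:
        cache.append(s)
        cache = compress(cache, max_capacity)
    return cache

final = decode_with_compression([0.1, 0.9, 0.2, 0.8, 0.3, 0.7], max_capacity=3)
print(final)  # [0.9, 0.8, 0.7]
```

The key difference from prefill-only compression is that pruning happens repeatedly inside the generation loop, so the cache never grows beyond the budget plus the interval between compression steps.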

Checklist

  • Tests are working (make test)
  • Code is formatted correctly (make style, on errors try fix with make format)
  • Copyright header is included
  • All commits are signed-off using git commit -s
  • (new press) mypress_press.py is in the presses directory
  • (new press) MyPress is in __init__.py
  • (new press) README.md is updated with a one-liner about the new press in the Available presses section
  • (new press) New press is in the default_presses list in tests/default_presses.py
  • (new press) A docstring is provided that follows the same structure as the existing ones

maxjeblick added 30 commits July 3, 2025 16:42
Signed-off-by: Maximilian Jeblick <maximilianjeblick@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented Sep 29, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@alessiodevoto
Collaborator Author

/ok to test a43fc19

@alessiodevoto
Collaborator Author

Is this normal, @maxjeblick?

tests/integration/test_ruler.py ssssssssssssssssssssssssssssssssssssssss [ 17%]

Comment thread kvpress/pipeline.py
logger.debug(f"Compressed Context Length: {cache.get_seq_length()}")

# Greedy decoding for each question
answers = []
Collaborator

Here, we don't exit the context manager after prefilling. This may break kvzip press.

Collaborator

@maxjeblick maxjeblick Sep 30, 2025

A straightforward solution might be to have two context managers:

should_perform_prefill_compression = press is not None and not isinstance(press, (DecodingPress, PrefillDecodingPress))
with press(self.model) if should_perform_prefill_compression else contextlib.nullcontext():

(and a subsequent with block for answer generation). This way, we ensure exactly the same control flow in case no decoding press is used.
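A minimal, runnable sketch of this selection logic, using empty stand-in press classes (BasePress, DecodingPress, PrefillDecodingPress here are stubs, not the real kvpress classes):

```python
# Sketch of the suggested prefill-context selection, with stand-in classes.
import contextlib

class BasePress:
    def __call__(self, model):
        # The real press returns a context manager that installs forward hooks.
        return contextlib.nullcontext()

class DecodingPress(BasePress):
    pass

class PrefillDecodingPress(BasePress):
    pass

def should_compress_at_prefill(press) -> bool:
    """Prefill compression applies only to presses that are not decoding presses."""
    return press is not None and not isinstance(press, (DecodingPress, PrefillDecodingPress))

def run_prefill(press, model):
    # Enter the press context only for prefill-time presses; a decoding press
    # instead stays active through the subsequent generation block.
    ctx = press(model) if should_compress_at_prefill(press) else contextlib.nullcontext()
    with ctx:
        pass  # the prefill forward pass would run here

print(should_compress_at_prefill(BasePress()))      # True
print(should_compress_at_prefill(DecodingPress()))  # False
```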

Collaborator Author

As @maxjeblick pointed out, KVZip is not supported because it is not a scorer press.

Collaborator Author

Ok sorry @maxjeblick, I missed the problem. I tried the two context managers approach, wdyt?

Collaborator

That looks good!

Comment thread evaluation/evaluate_decoding.py Outdated
Collaborator

@maxjeblick maxjeblick left a comment

Thanks a lot for working on decoding press and adapting the code!

Some comments:

  • IMO, decoding press can be refactored a bit; in particular, cache handling (quantized/non-quantized) now appears in various presses and can be factored out.
  • The notebook probably needs to be rerun to produce correct output (max generation length is too low). The generated text output could also be formatted more nicely for display.
  • The PR changes the pipeline logic: the with press context manager now exits AFTER generation, not before. This will most likely cause kvzip press to no longer work (it relies on the context manager exiting after prefilling).

Comment thread kvpress/presses/decoding_press.py Outdated
Target number of tokens to keep after compression.
hidden_states_buffer_size : int, default=128
Maximum number of hidden states to keep before compression. Larger values use more GPU memory.
Note: Some presses don't need buffered hidden states and can set this to 0 to use only the
Collaborator

typo

Comment thread kvpress/presses/decoding_press.py
Comment thread kvpress/presses/decoding_press.py
Comment thread kvpress/presses/decoding_press.py Outdated
)

cache_layer = cache.layers[module.layer_idx]
if isinstance(cache, QuantizedCache):
Collaborator

This is a candidate for refactoring, as it also appears in the base presses.

Collaborator Author

Moved to utils

return output
# print(f"Adding hidden states to buffer: {hidden_states.shape}")
# Add current hidden states to buffer for this layer
self.hidden_states_buffer[layer_idx].append(hidden_states.detach().clone())
Collaborator

The hidden states buffer might be longer than hidden_states_buffer_size.
IMO, the code makes sense; we may need to adapt the docstring to be more explicit.

Collaborator

We should already use torch.cat, s.t. self.hidden_states_buffer[layer_idx] is always a tensor; that makes it easier to use.

Collaborator Author

We say that in the docstring: "Buffered hidden states from recent decoding steps (shape: [batch, buffer_len, hidden_dim])"

Collaborator Author

We should already use torch.cat

Maybe it is cleaner like this? Handling the first torch.cat when the buffer is empty would require extra complexity and less readable code.
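The trade-off being discussed can be shown with a toy version of the buffer, using Python lists in place of tensors (list concatenation stands in for torch.cat; the class and method names are hypothetical, not the kvpress API):

```python
# Toy version of the buffered-hidden-states discussion.

class HiddenStatesBuffer:
    def __init__(self, buffer_size: int):
        self.buffer_size = buffer_size
        self.chunks: list[list[float]] = []  # one chunk per decoding step

    def append(self, hidden_states: list[float]) -> None:
        # Keeping a list of chunks avoids special-casing the first append
        # (the eager-concatenation alternative needs an "empty buffer" branch).
        self.chunks.append(hidden_states)

    def recent(self) -> list[float]:
        # Flatten on read and keep only the most recent buffer_size entries;
        # the stored chunks may therefore temporarily exceed buffer_size.
        merged: list[float] = []
        for chunk in self.chunks:
            merged.extend(chunk)
        return merged[-self.buffer_size:]

buf = HiddenStatesBuffer(buffer_size=3)
buf.append([1.0, 2.0])
buf.append([3.0, 4.0])
print(buf.recent())  # [2.0, 3.0, 4.0]
```

Eagerly concatenating on every append keeps the buffer a single tensor (easier for consumers), while the list-of-chunks version keeps the append path simpler; either works, as long as the docstring states which invariant holds.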

logger.debug(f"Applied decoding compression: " f"keys.shape: {keys.shape}, values.shape: {values.shape}")

# Update cache with compressed keys and values
if isinstance(cache, QuantizedCache):
Collaborator

Again: Could become a dedicated util function.

Comment thread kvpress/presses/prefill_decoding_press.py
Comment thread pyproject.toml Outdated
[project]
name = "kvpress"
version = "0.3.0"
version = "0.3.1"
Collaborator

I wouldn't update the version.
Decoding press will become version 1.0.0. With this PR we are adding the functionality, but we won't cut a release yet, so we can test this feature more thoroughly.

Comment thread kvpress/presses/generation/README.md Outdated
@@ -0,0 +1,111 @@
# Generation Presses (Experimental)
Collaborator

We can move this to the main readme in a new dropdown section.

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@alessiodevoto
Collaborator Author

alessiodevoto commented Oct 1, 2025

/ok to test a0053c9

1 similar comment

Comment thread kvpress/presses/base_press.py Outdated
@@ -132,13 +134,7 @@ def forward_hook(self, module: nn.Module, input: list[torch.Tensor], kwargs: dic

cache_layer = cache.layers[module.layer_idx]
if isinstance(cache, QuantizedCache):
Collaborator

We could also extract keys, values = extract_key_values(cache_layer).
WDYT?

Collaborator Author

Do you mean just

def extract_key_values(layer): return layer.keys, layer.values

Maybe a bit overkill? In the end we are just accessing two fields.

Collaborator

Maybe a bit overkill in the end we are just accessing two fields ?

The same if-else block is present in several parts of the code.
To me, it thus makes sense to extract the whole block.

Collaborator Author

I see, makes sense!
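The agreed-upon helper could look roughly like this, using stand-in cache-layer classes (the real transformers/kvpress cache classes differ; the names below are illustrative):

```python
# Hedged sketch of an extract_key_values helper centralizing the
# quantized/non-quantized if-else that recurs across presses.

class CacheLayer:
    def __init__(self, keys, values):
        self.keys, self.values = keys, values

class QuantizedCacheLayer:
    def __init__(self, keys, values):
        # A quantized layer stores compressed keys/values; dequantize() restores
        # them (identity here, purely for illustration).
        self._qkeys, self._qvalues = keys, values

    def dequantize(self):
        return self._qkeys, self._qvalues

def extract_key_values(cache_layer):
    """One home for the quantized/non-quantized branch, instead of copies in each press."""
    if isinstance(cache_layer, QuantizedCacheLayer):
        return cache_layer.dequantize()
    return cache_layer.keys, cache_layer.values

print(extract_key_values(CacheLayer([1], [2])))  # ([1], [2])
```

Extracting the whole branch (not just the two-field access) is what pays off: every call site shrinks to one line and the quantization handling can change in a single place.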

Comment thread kvpress/presses/decoding_press.py
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@alessiodevoto
Collaborator Author

/ok to test a2982f9

Comment thread kvpress/pipeline.py Outdated

with press(self.model) if press is not None else contextlib.nullcontext():
# We only perform prefill compression if the press is not a decoding or prefill decoding press
perform_prefill_compression = press is not None and not isinstance(press, (DecodingPress, PrefillDecodingPress))
Collaborator

PrefillDecodingPress needs to be excluded.

Collaborator Author

Right, my bad 🫠

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@alessiodevoto
Collaborator Author

@maxjeblick could you also check the README, to see if it is clear enough?

@alessiodevoto
Collaborator Author

/ok to test 31cf83c

@maxjeblick
Collaborator

Moved PR to #139 due to DCO issues.
@alessiodevoto I added you as co-author in b16417b which contains the content of this PR.

@maxjeblick maxjeblick closed this Oct 13, 2025