# Head-specific compression
In this notebook, we demonstrate how to implement head-specific compression, using Ada-SnapKV from [AdaKV](https://arxiv.org/abs/2407.11550) as an example. Before proceeding, please read `new_press.ipynb` to understand how standard compression works.

### Key observation

Different attention heads within an LLM exhibit significant differences in their attention patterns, such as how concentrated their attention is. This lets us distribute the overall cache budget across heads strategically: following AdaKV, heads with dispersed attention receive a larger share of the budget, while heads with concentrated attention receive a smaller one.
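A toy sketch of this idea, under simplifying assumptions: each head's per-token scores are given directly (in AdaKV they are derived from attention weights), and the layer's budget is filled by picking the highest scores across all heads at once. `allocate_per_head` is a hypothetical helper, not the kvpress or AdaKV implementation, and the paper's method may include refinements not shown here.

```python
import torch


def allocate_per_head(scores: torch.Tensor, budget: int) -> torch.Tensor:
    """Keep the `budget` highest scores across *all* heads of a layer and
    return how many tokens each head ends up keeping.

    scores: [num_heads, seq_len]
    """
    num_heads, seq_len = scores.shape
    top_idx = scores.flatten().topk(budget).indices  # indices into num_heads * seq_len
    head_of_each = top_idx // seq_len                # head that owns each kept entry
    return torch.bincount(head_of_each, minlength=num_heads)


seq_len = 10
concentrated = torch.tensor([0.91] + [0.01] * (seq_len - 1))  # one dominant token
distributed = torch.full((seq_len,), 1.0 / seq_len)            # flat attention
scores = torch.stack([concentrated, distributed])

budget = seq_len  # keep half of the 2 * 10 cached entries
counts = allocate_per_head(scores, budget)
print(counts)  # tensor([1, 9])
```

In this toy case the adaptive split keeps 1.81 of the total attention mass of 2.0 (0.91 + 9 × 0.1), versus 1.45 for an even 5/5 split (0.91 + 4 × 0.01 + 5 × 0.1): the concentrated head loses almost nothing by giving up its budget.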

### How to Achieve Head-specific Compression in Practice?
Under the same overall budget, head-specific compression can retain more of the important context than standard compression, which keeps the same number of tokens in every head. However, it leaves each head with a different cache length, which the regular `[bsz, num_heads, seq_len, head_dim]` cache tensor cannot represent. To address this, we provide a flattened cache layout, `DynamicCacheSplitHeadFlatten`, complemented by:

* Custom CUDA kernels for cache update operations: `update_flatten_klenN_view`
* Variable-length Flash Attention, which computes attention over caches whose length differs per head: `flash_attn_varlen_func`
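To make the second point concrete, here is a slow, pure-PyTorch reference of what variable-length attention computes at one decoding step, assuming every (batch, head) pair is treated as its own sequence. The function name `varlen_decode_attention` and the one-query-per-segment setup are illustrative assumptions; this is neither `flash_attn_varlen_func` itself (which fuses the loop into a single CUDA kernel) nor kvpress's wrapper around it.

```python
import torch
import torch.nn.functional as F


def varlen_decode_attention(q, k_flat, v_flat, cu_seqlens_k):
    """Slow reference: each segment (one (batch, head) pair) attends only to its own cached tokens.

    q:             [num_segments, head_dim], one query vector per segment
    k_flat/v_flat: [total_tokens, head_dim], flattened cache
    cu_seqlens_k:  [num_segments + 1], cumulative cache lengths starting at 0
    """
    scale = q.shape[-1] ** -0.5
    outputs = []
    for i in range(q.shape[0]):
        start, end = int(cu_seqlens_k[i]), int(cu_seqlens_k[i + 1])
        k, v = k_flat[start:end], v_flat[start:end]          # this segment's tokens only
        weights = torch.softmax((k @ q[i]) * scale, dim=-1)  # [len_i]
        outputs.append(weights @ v)                          # [head_dim]
    return torch.stack(outputs)                              # [num_segments, head_dim]


# Three heads that kept 3, 1 and 5 tokens after head-specific compression
torch.manual_seed(0)
head_lens = torch.tensor([3, 1, 5])
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.long), head_lens.cumsum(0)])  # [0, 3, 4, 9]
q = torch.randn(3, 4)
k_flat, v_flat = torch.randn(9, 4), torch.randn(9, 4)

out = varlen_decode_attention(q, k_flat, v_flat, cu_seqlens_k)
print(out.shape)  # torch.Size([3, 4])

# Cross-check the last head against PyTorch's built-in attention on just its 5 tokens
ref = F.scaled_dot_product_attention(
    q[2].view(1, 1, 1, 4), k_flat[4:9].view(1, 1, 5, 4), v_flat[4:9].view(1, 1, 5, 4)
)
print(torch.allclose(out[2], ref.view(4), atol=1e-5))  # True
```

No padding or masking is needed: a head that kept a single token simply returns that token's value vector.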


In [1]:
import torch
from torch import nn
from transformers import pipeline
from kvpress import BasePress, AdaSnapKVPress

context = "In this step-by-step guide, you will learn how to create a new press in kvpress !"
question = "\nWhat is the purpose of this guide?"

## Overview of How to Use Head-Specific Compression
Here’s an example using Ada-SnapKV (AdaKV) to illustrate the process:

1. Replace the standard Flash Attention with variable-length Flash Attention **before loading the LLM**.
2. Instantiate a head-specific press (e.g., Ada-SnapKV) and pass it to the pipeline.


In [2]:

from kvpress.ada_attn import replace_var_flash_attn


device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # or a local path to the same checkpoint

# [Important] replace the vanilla flash attention with flash_attn_varlen_func before loading the model
attn_implementation = "flash_attention_2"
replace_var_flash_attn(model=ckpt)

# compress 50% of the cache; window_size=2 for easier demonstration
ada_snapkv_press = AdaSnapKVPress()
ada_snapkv_press.compression_ratio = 0.5
ada_snapkv_press.window_size = 2

# construct the pipeline
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs={"attn_implementation": attn_implementation},
)


You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.



In [3]:
# tokenize the context (reused in the cells below), then generate an answer
tokens = pipe.tokenizer(context, return_tensors="pt").to(device)
print("Answer:", pipe(context, question=question, press=ada_snapkv_press)["answer"])

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


Answer: The purpose of this guide is to teach users how to create a new press in kvpress, which is likely a software or tool used for creating and managing press releases or other types of content.


## A Closer Look at the Role of Flattened Cache Layout in Head-Specific Compression

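Regular cache layout: `[bsz, num_heads, seq_len, head_dim]`. Flattened cache layout: `[(seq_len of head 1, batch 1) + (seq_len of head 2, batch 1) + ... + (seq_len of head num_heads, batch bsz), head_dim]`.

In other words, the cached tokens of every (batch, head) pair are concatenated along a single dimension, so each head can keep a different number of tokens; the per-head lengths are tracked separately, as `head_lens` in the cache metadata.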
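The layout can be reproduced on a toy tensor with plain PyTorch. The helper `flatten_kv` below is hypothetical and not part of kvpress, and `DynamicCacheSplitHeadFlatten` may organize its internals differently; the sketch only shows how the retained entries of each head map onto one `[total_tokens, head_dim]` tensor plus per-head lengths. It also builds `cu_seqlens`, an int32 cumulative-length vector starting at 0, which is the format variable-length kernels such as `flash_attn_varlen_func` take.

```python
import torch


def flatten_kv(keys: torch.Tensor, keep: torch.Tensor):
    """Pack the retained entries of a regular cache into a flattened layout.

    keys: [bsz, num_heads, seq_len, head_dim]  regular cache layout
    keep: [bsz, num_heads, seq_len]            boolean mask of entries to retain
    """
    # Boolean indexing visits the masked dims in row-major order
    # (batch 0 / head 0, batch 0 / head 1, ..., batch bsz-1 / head num_heads-1),
    # which is the ordering of the flattened layout described above.
    flat_keys = keys[keep]                                   # [total_kept, head_dim]
    head_lens = keep.sum(dim=-1).flatten().to(torch.int32)  # [bsz * num_heads]
    cu_seqlens = torch.zeros(head_lens.numel() + 1, dtype=torch.int32, device=keys.device)
    cu_seqlens[1:] = torch.cumsum(head_lens, dim=0, dtype=torch.int32)
    return flat_keys, head_lens, cu_seqlens


bsz, num_heads, seq_len, head_dim = 1, 3, 4, 2
keys = torch.arange(bsz * num_heads * seq_len * head_dim, dtype=torch.float32)
keys = keys.view(bsz, num_heads, seq_len, head_dim)
keep = torch.tensor([[[1, 1, 1, 1],    # head 0 keeps everything
                      [1, 0, 0, 1],    # head 1 keeps tokens 0 and 3
                      [0, 0, 0, 1]]],  # head 2 keeps only the last token
                    dtype=torch.bool)

flat_keys, head_lens, cu_seqlens = flatten_kv(keys, keep)
print(flat_keys.shape)      # torch.Size([7, 2])
print(head_lens.tolist())   # [4, 2, 1]
print(cu_seqlens.tolist())  # [0, 4, 6, 7]
```

Head `i` of the flattened tensor is the slice `flat_keys[cu_seqlens[i]:cu_seqlens[i + 1]]`. With `bsz = 1` and 8 KV heads of 21 tokens each, the same packing yields the `[168, 128]` shape printed in the cell below (8 × 21 = 168 rows).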



In [4]:
from kvpress.ada_cache import DynamicCacheSplitHeadFlatten

# DynamicCacheSplitHeadFlatten Class is used to flatten the cache for all heads
flatten_cache = DynamicCacheSplitHeadFlatten()
with torch.no_grad():
    outputs_without_press = pipe.model(**tokens, past_key_values=flatten_cache)

flatten_cache = DynamicCacheSplitHeadFlatten()
with torch.no_grad(), ada_snapkv_press(pipe.model):
    output_with_press = pipe.model(**tokens, past_key_values=flatten_cache)

print(f"There are {len(outputs_without_press.past_key_values.metadata_list[0].head_lens)} heads with the same lengths of cached tokens: {outputs_without_press.past_key_values.metadata_list[0].head_lens}")
print(f"Flatten Cache shape without press: {outputs_without_press.past_key_values.key_cache[0].shape}\n")

print(f"After head-specific compression, there are {len(output_with_press.past_key_values.metadata_list[0].head_lens)} heads with different lengths of cached tokens: {output_with_press.past_key_values.metadata_list[0].head_lens}")
print(f"Flatten Cache shape with press: {output_with_press.past_key_values.key_cache[0].shape}\n")

# compression ratio = 1 - (kept rows) / (original rows), measured on layer 0
compress_ratio = 1 - output_with_press.past_key_values.key_cache[0].shape[0] / outputs_without_press.past_key_values.key_cache[0].shape[0]

print(f"Overall compression ratio: 1 - sum({output_with_press.past_key_values.metadata_list[0].head_lens.cpu().tolist()})/sum({outputs_without_press.past_key_values.metadata_list[0].head_lens.cpu().tolist()}) = {compress_ratio}\n")

# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).
print("Answer:", pipe(context, question=question, press=ada_snapkv_press)["answer"])

There are 8 heads with the same lengths of cached tokens: tensor([21, 21, 21, 21, 21, 21, 21, 21], device='cuda:0', dtype=torch.int32)
Flatten Cache shape without press: torch.Size([168, 128])

After head-specific compression, there are 8 heads with different lengths of cached tokens: tensor([ 9, 12, 11, 14,  5, 15, 13,  5], device='cuda:0', dtype=torch.int32)
Flatten Cache shape with press: torch.Size([84, 128])

Overall compression ratio: 1 - sum([9, 12, 11, 14, 5, 15, 13, 5])/sum([21, 21, 21, 21, 21, 21, 21, 21]) = 0.5

Answer: The purpose of this guide is to teach users how to create a new press in kvpress, which is likely a software or tool used for creating and managing press releases or other types of content.


## Create your own head-specific cache compression method
1. [What You Need to Do]: Implement a new cache compression method by inheriting from the `AdaBasePress` class and overriding `score`. For a head-specific method, it is recommended to return a masked score that directly marks the KV cache pairs you wish to retain in each head by setting their score to the maximum value.
2. By default, the `forward_hook` of `AdaBasePress` retains the highest-scoring cache pairs across all heads of a layer jointly (so different heads can end up with different lengths), then updates the flattened cache and its metadata.
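Step 2 can be sketched in plain PyTorch as a single top-k over all heads of a layer. `layerwise_keep_mask` is a hypothetical helper, and the budget rounding `int(num_heads * seq_len * (1 - compression_ratio))` is an assumption; the actual `AdaBasePress.forward_hook` may differ in such details. Fed with scores built the same way as in `MyAdaPress` below (8 KV heads, 21 tokens, `compression_ratio=0.5`), it yields per-head lengths that match the `head_lens` printed for that example.

```python
import torch


def layerwise_keep_mask(scores: torch.Tensor, compression_ratio: float) -> torch.Tensor:
    """Keep the highest-scoring entries across all heads of a layer.

    scores: [bsz, num_heads, seq_len]; returns a boolean mask of the same shape.
    """
    bsz, num_heads, seq_len = scores.shape
    n_kept = int(num_heads * seq_len * (1 - compression_ratio))  # layer-wide budget (assumed rounding)
    flat = scores.reshape(bsz, -1)                                # all heads compete jointly
    top_idx = flat.topk(n_kept, dim=-1).indices
    keep = torch.zeros_like(flat, dtype=torch.bool)
    keep.scatter_(1, top_idx, torch.ones_like(top_idx, dtype=torch.bool))
    return keep.view(bsz, num_heads, seq_len)


# Scores built like MyAdaPress's: later positions score higher, while the first
# 4 positions (attention sinks) and the first two heads get the maximum value.
bsz, num_heads, seq_len = 1, 8, 21
scores = torch.arange(seq_len, dtype=torch.float32).expand(bsz, num_heads, seq_len).clone()
max_value = torch.finfo(scores.dtype).max
scores[:, :, :4] = max_value
scores[:, :2, :] = max_value

keep = layerwise_keep_mask(scores, compression_ratio=0.5)
print(keep.sum(dim=-1).tolist())  # [[21, 21, 7, 7, 7, 7, 7, 7]]
```

Of the 84 kept entries, 66 are forced by the maximum score (2 × 21 full heads plus 6 × 4 sinks); the remaining 18 go to positions 18, 19 and 20 of the other six heads, giving 4 + 3 = 7 tokens each.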

Below, `MyAdaPress` shows a simple example that retains the full KV cache in the first two KV heads (indices 0 and 1), while the other heads follow a StreamingLLM-style policy: they keep the first 4 tokens (attention sinks) plus the most recent tokens, because the base score increases with position and the remaining budget therefore goes to the latest positions.

In [None]:
from kvpress.presses.base_press import AdaBasePress


class MyAdaPress(AdaBasePress):
    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> torch.Tensor:
        cache_metadata = kwargs.get("metadata", None)
        assert cache_metadata is not None, "cache metadata is required for AdaBasePress subclasses"

        # Reshape the flattened keys/values back to (bsz, num_key_value_heads, q_len, head_dim) for easier scoring.
        # This is valid here because, before the first compression, every head caches the same number of tokens.
        # In multi-turn compression scenarios, heads may already hold different numbers of tokens, so this reshape would not be valid.
        keys = keys.view(cache_metadata.bsz, cache_metadata.num_key_value_heads, -1, keys.shape[-1])
        values = values.view(cache_metadata.bsz, cache_metadata.num_key_value_heads, -1, values.shape[-1])
        seq_len = keys.shape[-2]

        # StreamingLLM-style base score: later positions score higher, so recent tokens are kept first
        scores = torch.arange(seq_len, device=keys.device).float()
        scores = scores.unsqueeze(0).unsqueeze(0).repeat(cache_metadata.bsz, cache_metadata.num_key_value_heads, 1)

        max_value = torch.finfo(scores.dtype).max

        # always keep the first 4 tokens (attention sinks) in every head
        scores[:, :, :4] = max_value

        # keep the full cache in the first two KV heads (indices 0 and 1)
        scores[:, :2, :] = max_value
        
        return scores


press = MyAdaPress(0.5)

flatten_cache = DynamicCacheSplitHeadFlatten()
with torch.no_grad():
    outputs_without_press = pipe.model(**tokens, past_key_values=flatten_cache)

flatten_cache = DynamicCacheSplitHeadFlatten()
with torch.no_grad(), press(pipe.model):
    output_with_press = pipe.model(**tokens, past_key_values=flatten_cache)

print(f"There are {len(outputs_without_press.past_key_values.metadata_list[0].head_lens)} heads with the same lengths of cached tokens: {outputs_without_press.past_key_values.metadata_list[0].head_lens}")
print(f"Flatten Cache shape without press: {outputs_without_press.past_key_values.key_cache[0].shape}\n")

print(f"After head-specific compression, there are {len(output_with_press.past_key_values.metadata_list[0].head_lens)} heads with different lengths of cached tokens: {output_with_press.past_key_values.metadata_list[0].head_lens}")
print(f"Flatten Cache shape with press: {output_with_press.past_key_values.key_cache[0].shape}\n")

# compression ratio = 1 - (kept rows) / (original rows), measured on layer 0
compress_ratio = 1 - output_with_press.past_key_values.key_cache[0].shape[0] / outputs_without_press.past_key_values.key_cache[0].shape[0]
print(f"Overall compression ratio: 1 - sum({output_with_press.past_key_values.metadata_list[0].head_lens.cpu().tolist()})/sum({outputs_without_press.past_key_values.metadata_list[0].head_lens.cpu().tolist()}) = {compress_ratio}\n")

print("Answer:", pipe(context, question=question, press=press)["answer"])

There are 8 heads with the same lengths of cached tokens: tensor([21, 21, 21, 21, 21, 21, 21, 21], device='cuda:0', dtype=torch.int32)
Flatten Cache shape without press: torch.Size([168, 128])

After head-specific compression, there are 8 heads with different lengths of cached tokens: tensor([21, 21,  7,  7,  7,  7,  7,  7], device='cuda:0', dtype=torch.int32)
Flatten Cache shape with press: torch.Size([84, 128])

Overall compression ratio: 1 - sum([21, 21, 7, 7, 7, 7, 7, 7])/sum([21, 21, 21, 21, 21, 21, 21, 21]) = 0.5

Answer: This guide is intended to help users create a new press in kvpress, which is a software for creating and managing press releases. The purpose of this guide is to provide step-by-step instructions on how to create a new press in kvpress, allowing
