# Import

In [None]:
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from rkv.monkeypatch import replace_llama, replace_qwen2, replace_qwen3

# Prompt Template

prompt_template = "You are given a math problem.\n\nProblem: {question}\n\n You need to solve the problem step by step. First, you need to provide the chain-of-thought, then provide the final answer.\n\n Provide the final answer in the format: Final answer:  \\boxed{{}}"

In [None]:
# Instruction wrapper around a single question; `{question}` is filled in below
# and the doubled braces render as a literal `{}` after `.format(...)`.
prompt_template = "You are given a math problem.\n\nProblem: {question}\n\n You need to solve the problem step by step. First, you need to provide the chain-of-thought, then provide the final answer.\n\n Provide the final answer in the format: Final answer:  \\boxed{{}}"

# Read the dataset as JSON Lines (one JSON object per line).
# The encoding is given explicitly so decoding does not depend on the platform
# locale, and blank lines (e.g. a trailing newline at end of file) are skipped
# instead of crashing `json.loads`.
with open("../data/gsm8k.jsonl", "r", encoding="utf-8") as f:
    test_data = [json.loads(line) for line in f if line.strip()]

# Pick one example and build its prompt.
sample_id = 1281
sample = test_data[sample_id]
prompt = prompt_template.format(question=sample["question"])
# The reference answer is the text after the last "####" marker in the
# record's "answer" field.
answer = sample["answer"].split("####")[-1].strip()

# Select Model

Choose from:
- deepseek-ai/DeepSeek-R1-Distill-Llama-8B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-14B
- deepseek-ai/DeepSeek-R1-Distill-Qwen-32B

In [None]:
# Checkpoint identifier passed to `AutoTokenizer.from_pretrained` and
# `AutoModelForCausalLM.from_pretrained` below; its lower-cased form also
# decides which rkv monkeypatch (qwen3 / llama / qwen2) is applied.
model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# Method Configuration

- Choose method from: rkv, snapkv, streamingllm, h2o, analysiskv
- Choose budget from: 128, 256, 512, 1024
- Choose mix_lambda from 0 to 1.

- When mix_lambda=0, the selection is dominated by redundancy.
- When mix_lambda=1, the selection is dominated by attention.

- Analysiskv is a special method that does not compress the KV cache and only returns the selection patterns. This lets us observe the importance scores without compressing KV.

In [None]:
# Settings handed to the rkv monkeypatch functions (replace_llama / replace_qwen2 /
# replace_qwen3) when the model is loaded.
compression_config = dict(
    # One of: rkv, snapkv, streamingllm, h2o, analysiskv.
    method="rkv",
    method_config=dict(
        budget=128,  # documented choices: 128, 256, 512, 1024
        window_size=8,
        mix_lambda=0.07,  # in [0, 1]; 0 -> redundancy-dominated, 1 -> attention-dominated
        retain_ratio=0.2,
        retain_direction="last",
        # Presumably what populates `kv_cluster.kept_token_indices`, which the
        # visualization cells read -- confirm against the rkv implementation.
        record_kept_token_indices=True,
    ),
    compression=None,
    update_kv=True,
)

# Extra fields merged into `model.config` via `model.config.update(...)`.
model_config = dict(
    divide_method="newline",
    divide_length=128,
    compression_content="all",
)

## Load Model

In [None]:
# Fast tokenizer for the chosen checkpoint, configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    padding_side="left",
)

# Reuse the EOS token for padding when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Select the rkv monkeypatch from the model name and apply it BEFORE the model
# is instantiated (presumably so the patched attention code is what gets
# loaded -- confirm in rkv.monkeypatch). "qwen3" is checked first, then
# "llama"; every other name is treated as qwen2.
_family = model_name.lower()
if "qwen3" in _family:
    _apply_patch = replace_qwen3
elif "llama" in _family:
    _apply_patch = replace_llama
else:
    _apply_patch = replace_qwen2
_apply_patch(compression_config)

# Load the weights in bfloat16 and switch to evaluation mode.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.eval()

# Merge the notebook-level settings into the model config.
model.config.update(model_config)

# Id of the LAST token each newline-terminated string encodes to.
model.newline_token_ids = [
    tokenizer.encode(text)[-1]
    for text in ("\n", ".\n", ")\n", "\n\n", ".\n\n", ")\n\n")
]

# Id of the last token of the closing think tag.
model.after_think_token_ids = [tokenizer.encode("</think>")[-1]]

# NOTE(review): the GPU index is hard-coded to 4; adjust for the local machine.
device = torch.device("cuda:4")
model.to(device)

# Inference

In [None]:
# Tokenize the single prompt and move the tensors to the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Deterministic decoding: one beam, no sampling, at most 8192 total tokens
# (prompt included, since `max_length` is used).
outputs = model.generate(**inputs, max_length=8192, num_beams=1, do_sample=False)

## Display model answer
generated_ids = outputs[0]
prompt_length = inputs["input_ids"].shape[1]
# Statistics are read from the kv_cluster attached to the last decoder layer.
last_layer_cluster = model.model.layers[-1].self_attn.kv_cluster

print("Sample ID: ", sample_id)
print(tokenizer.decode(generated_ids, skip_special_tokens=True))
print("\n\nGround Truth: ", answer)
print("Generation length:", generated_ids.shape[0] - prompt_length)
print("Compression Steps:", len(last_layer_cluster.kept_token_indices))
print("Evicted tokens:", last_layer_cluster.evicted_token_num)

# Release cached GPU memory that is no longer in use.
torch.cuda.empty_cache()

# Visualize token eviction

## Visualize the token eviction pattern for a given head for one compression step

In [None]:
from rkv.utils import visualize_token_eviction

# Decoder layer, attention head and compression step to inspect.
layer_id, head_id, step_idx = 31, 6, 5

# Kept-token records stored on the selected layer's kv_cluster; the cell
# indexes this list by compression step.
kept_indices_lst = (
    model.model.layers[layer_id].self_attn.kv_cluster.kept_token_indices
)
visualize_token_eviction(
    outputs[0],
    kept_indices_lst[step_idx],
    tokenizer,
    head_idx=head_id,
)

## Visualize the token eviction pattern for a given head at each compression step

In [None]:
from rkv.utils import visualize_multistep_token_eviction

# Decoder layer, attention head and compression step to inspect.
layer_id, head_id, step_idx = 31, 6, 5

# The full list of kept-token records is passed here; the step is selected
# through the `step_idx` keyword.
kept_indices_lst = (
    model.model.layers[layer_id].self_attn.kv_cluster.kept_token_indices
)
visualize_multistep_token_eviction(
    outputs[0],
    kept_indices_lst,
    tokenizer,
    head_idx=head_id,
    step_idx=step_idx,
)

## Visualize the token eviction pattern for all heads at each compression step

In [None]:
from rkv.utils import visualize_multistep_token_eviction_by_head

# Decoder layer and compression step to inspect (no single head is selected).
layer_id, step_idx = 31, 5

# Kept-token records stored on the selected layer's kv_cluster.
kept_indices_lst = (
    model.model.layers[layer_id].self_attn.kv_cluster.kept_token_indices
)

# Report how many compression steps were recorded before plotting.
print("Total Step: ", len(kept_indices_lst))
visualize_multistep_token_eviction_by_head(
    outputs[0],
    kept_indices_lst,
    tokenizer,
    step_idx=step_idx,
    aggregate=True,
)

# aggregate: when set to `False`, later heads overwrite earlier heads in the plot; when set to `True`, the plot shows how many heads cover each token.

## Visualize the token eviction score for all heads at each compression step

In [None]:
from rkv.utils import visualize_multistep_token_eviction_score_by_head

# Decoder layer, attention head and compression step to inspect.
layer_id = 31
head_id = 6
step_idx = 5

# Kept-token indices and attention scores are both read from the kv_cluster
# of the selected layer.
kv_cluster = model.model.layers[layer_id].self_attn.kv_cluster
kept_indices_lst = kv_cluster.kept_token_indices
kept_attention_scores_lst = kv_cluster.kept_attention_scores

# Pass the head chosen above: the original wrote `head_idx=head_idx`, a name
# that is never defined, so the call raised NameError.
visualize_multistep_token_eviction_score_by_head(
    outputs[0],
    kept_indices_lst,
    kept_attention_scores_lst,
    tokenizer,
    step_idx=step_idx,
    head_idx=head_id,
)