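This notebook runs SGLang inference on GSM8K. Note: although it was originally described as a FullKV run, the engine below is configured with `compress_algorithm = "RKV"`.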

## Setup

In [None]:
import os
import time
import json
import asyncio
from tqdm import tqdm
from dataclasses import dataclass

import sglang as sgl

# Create every experiment directory up front; `exist_ok=True` makes re-running
# this cell harmless when the directories are already there.
for _experiment_dir in (
    "./z_experiment/evaluation/",
    "./z_experiment/output/",
    "./z_experiment/results/",
):
    os.makedirs(_experiment_dir, exist_ok=True)


@dataclass
class Config:
    """Run settings shared by the cells of this notebook."""

    # Model identifier/path handed to `sgl.Engine(model_path=...)`.
    model_path: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    # JSONL file that `main()` writes the per-sample results to.
    save_path: str = "./z_experiment/output/test_output.jsonl"
    # Number of lines read from the dataset file before loading stops.
    data_size: int = 10


config = Config()

In [None]:
# Load data: read up to `config.data_size` samples from the JSONL dataset and
# build one prompt per sample. `prompts` and `test_data` stay index-aligned.
PROMPT_TEMP = "You are given a math problem.\n\nProblem: {question}\n\n You need to solve the problem step by step. First, you need to provide the chain-of-thought, then provide the final answer.\n\n Provide the final answer in the format: Final answer:  \\boxed{{}}"

test_data = []
prompts = []
# Explicit UTF-8 so decoding does not depend on the platform's locale default.
with open("./evaluation/data/test_one.jsonl", encoding="utf-8") as f:
    for line in f:
        # Stop once the requested number of samples has been collected.
        if len(test_data) == config.data_size:
            break
        # Tolerate blank lines (e.g. a trailing newline at end of file), which
        # `json.loads` would otherwise reject.
        if not line.strip():
            continue
        sample = json.loads(line)
        prompt = PROMPT_TEMP.format(question=sample["question"])
        prompts.append(prompt)

        sample["prompt"] = prompt
        # Position of this sample within `test_data` / `prompts`.
        sample["index"] = len(test_data)

        test_data.append(sample)

## Load model

In [None]:
# Compression settings forwarded to the engine. They are kept as module-level
# names so they can be tweaked per run.
compress_algorithm = "RKV"
compress_max_window = 8
compress_max_prompt = 128
compress_divide_length = 64
# Options noted by the original author: "step_length", "newline".
compress_divide_method = "step_length"

# General engine options.
_engine_options = {
    "model_path": config.model_path,
    "dtype": "bfloat16",
    "disable_overlap_schedule": True,
}

# Compression options, passed through under the same keyword names.
_compress_options = {
    "compress_algorithm": compress_algorithm,
    "compress_max_window": compress_max_window,
    "compress_max_prompt": compress_max_prompt,
    "compress_divide_length": compress_divide_length,
    "compress_divide_method": compress_divide_method,
}

llm = sgl.Engine(**_engine_options, **_compress_options)


def main():
    """Generate completions for `prompts`, save them, and print throughput.

    Reads the module-level `llm`, `prompts`, `test_data` and `config`.
    Each entry of `test_data` that has a corresponding output is updated in
    place with the generated text and token counts. All of `test_data` is
    then written as JSON Lines to `config.save_path`.

    The engine is shut down before returning, including when generation
    raises, so it cannot be reused after this call.
    """
    sampling_params = {"temperature": 0.0, "top_p": 0.95, "max_new_tokens": 8192}

    try:
        # perf_counter is a monotonic clock intended for measuring intervals.
        start_time = time.perf_counter()
        outputs = llm.generate(prompts, sampling_params)
        end_time = time.perf_counter()
    finally:
        # Always release the engine, even if generation failed.
        llm.shutdown()

    for sample_idx, output in enumerate(outputs):
        meta_info = output["meta_info"]
        prompt_tokens = meta_info["prompt_tokens"]
        completion_tokens = meta_info["completion_tokens"]

        record = test_data[sample_idx]
        record["output"] = output["text"]
        record["prefill_tokens"] = prompt_tokens
        record["output_tokens"] = completion_tokens
        record["total_tokens"] = prompt_tokens + completion_tokens

    with open(config.save_path, "w", encoding="utf-8") as fp:
        fp.writelines(json.dumps(record) + "\n" for record in test_data)

    total_time = end_time - start_time
    total_tokens_generated = sum(
        output["meta_info"]["completion_tokens"] for output in outputs
    )
    throughput_tokens = total_tokens_generated / total_time
    throughput_requests = len(prompts) / total_time

    print(f"Total execution time: {total_time:.2f} seconds")
    print(f"Throughput (tokens/s): {throughput_tokens:.2f}")
    print(f"Throughput (requests/s): {throughput_requests:.2f}")

In [None]:
# Cap the run at MAX_PROMPTS samples. `test_data` is truncated together with
# `prompts` so the two lists stay index-aligned and the saved file contains
# only records that actually received an output.
MAX_PROMPTS = 100
prompts = prompts[:MAX_PROMPTS]
test_data = test_data[:MAX_PROMPTS]

if __name__ == "__main__":
    main()

## Evaluation

Use the following command to evaluate the GSM8K results.

```bash
python evaluation/math_eval_all_v2.py \
    --exp_name "evaluation" \
    --output_dir "./z_experiment/results" \
    --base_dir "./z_experiment/output" \
    --dataset gsm8k
```