97 changes: 96 additions & 1 deletion docs/source/user-guide/sparse-attention/esa.md
@@ -1 +1,96 @@
# ESA
# ESA: A Simple Example of Sparse Attention Implementation Based on UCM

## 🔍 Overview

ESA provides developers with an intuitive example of how to implement their own sparse attention algorithms based on the UCM framework. It includes the following components: KV block representation computation, asynchronous retrieval of the top-K relevant blocks, and non-blocking loading of KV blocks from SSD to HBM.

## 🚦 Quick Start

### Basic Usage
ESA can be launched using the following command:
```shell
export MODEL_PATH="/path/to/model" # For example: /home/models/Qwen2.5-14B-Instruct
export DATASET_PATH="/path/to/longbench/multifieldqa_zh.jsonl" # For example: /home/data/Longbench/data/multifieldqa_zh.jsonl
export DATA_DIR="/path/to/kv_cache" # For example: /home/data/kv_cache
python examples/offline_inference_esa.py
```
ESA can be configured by modifying `ucm_sparse_config` in `examples/offline_inference_esa.py`.
```python
...
ktc = KVTransferConfig(
    kv_connector=name,
    kv_connector_module_path=module_path,
    kv_role="kv_both",
    kv_connector_extra_config={
        "ucm_connector_name": "UcmNfsStore",
        "ucm_connector_config": {
            "storage_backends": "/path/to/data",
            "kv_block_size": 33554432,
        },
        "ucm_sparse_config": {
            "ESA": {
                "init_window_sz": 1,
                "local_window_sz": 2,
                "min_blocks": 4,
                "sparse_ratio": 0.3,
                "retrieval_stride": 5,
            }
        },
    },
)
...
```

## 🎯 Key Design

- KV Block Representation Computation
ESA applies the `mean` function along the block size axis to obtain the representation of each KV block.
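A minimal sketch of this step, assuming the cached keys of one request are laid out as `[num_blocks, block_size, num_kv_heads, head_dim]` (the tensor layout and function name are illustrative, not the actual UCM API):
```python
import torch

def compute_block_representations(key_cache: torch.Tensor) -> torch.Tensor:
    """Reduce each KV block to one representative vector per head.

    key_cache: [num_blocks, block_size, num_kv_heads, head_dim]
    returns:   [num_blocks, num_kv_heads, head_dim]
    """
    # Averaging over the block-size (token) axis keeps a single vector per block,
    # which is later compared against the decode query during retrieval.
    return key_cache.mean(dim=1)
```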

- Asynchronous Retrieval and Loading
During the decoding stage, ESA refreshes the context KV blocks once every `retrieval_stride` decode steps (configured in `ucm_sparse_config`, default `5`), and uses fine-grained scheduling to keep the retrieval and loading tasks asynchronous.
<p align="center">
<img alt="UCM" src="../../_static/images/esa_async_retrieval_and_load.png" width="80%">
</p>

In the second step of each period, the retrieval of the most important KV blocks is initiated. The pseudocode is as follows:
```python
def start_retrieval(self, query, forward_context):
    self.retrieval_task = self.retrieval_worker.submit(
        query, kv_block_representations=kv_block_representations
    )
```
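Conceptually, the retrieval worker performs a top-K search over the block representations. The following self-contained sketch illustrates the idea (it is not the actual `RetrievalWorker` implementation, and it assumes the query heads and KV heads share the same layout):
```python
import torch

def retrieve_topk_blocks(query: torch.Tensor, block_repr: torch.Tensor, k: int) -> list[int]:
    """Score every KV block against the current decode query and return the top-K block indices.

    query:      [num_heads, head_dim]   (single decode-step query token)
    block_repr: [num_blocks, num_heads, head_dim]
    """
    # Dot-product similarity per head, averaged over heads -> one score per block.
    scores = torch.einsum("hd,bhd->bh", query, block_repr).mean(dim=-1)
    k = min(k, scores.shape[0])
    return torch.topk(scores, k).indices.tolist()
```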
Then, in the last step of the current period, we wait for the retrieval worker to complete and launch the transfer task that loads the most relevant blocks. The pseudocode is:
```python
def wait_retrieval_and_start_load(self):
    topk_blocks = self.retrieval_task.result()
    self.loading_task = self.launch_transfer_task(
        "load", topk_blocks, target_HBM_addresses
    )
```
Finally, at the beginning of the next period, the transfer task is synchronized, and the KV caches in HBM are updated. The pseudocode is:
```python
def wait_transfer_task_done(self):
    ret = self.store_instance.wait(self.loading_task)
```
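Putting the three phases together, the scheduling within one `retrieval_stride` period can be sketched as follows (the step indices are an assumption based on the description above, not the exact UCM control flow):
```python
def on_decode_step(self, step_in_period: int, query, forward_context):
    # step_in_period cycles through 0 .. retrieval_stride - 1
    if step_in_period == 0:
        # Start of the period: make sure the blocks requested in the
        # previous period have finished loading into HBM.
        self.wait_transfer_task_done()
    elif step_in_period == 1:
        # Second step: kick off the asynchronous top-K retrieval.
        self.start_retrieval(query, forward_context)
    elif step_in_period == self.retrieval_stride - 1:
        # Last step: collect the retrieval result and start the async load.
        self.wait_retrieval_and_start_load()
```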

## 🔥 Results
The following results were obtained using `Qwen2.5-14B-Instruct` under the specified hyperparameters:
```python
"ucm_sparse_config": {
"ESA": {
"init_window_sz": 1,
"local_window_sz": 2,
"min_blocks": 4,
"sparse_ratio": 0.3,
"retrieval_stride": 5
}
},
```

### 🏆 Performance

### 📈 Accuracy
We use [LongBench](https://huggingface.co/datasets/zai-org/LongBench) to evaluate the accuracy of the ESA algorithm.
| Dataset | F1-Score |
|-------|-----------|
| multifieldqa_zh | 59.4 |
| dureader | 26.4 |
57 changes: 46 additions & 11 deletions examples/offline_inference_esa.py
@@ -1,6 +1,7 @@
import contextlib
import json
import os
import sys
import time
from dataclasses import asdict

@@ -13,15 +14,50 @@

from ucm.logger import init_logger

MODEL_PATH = "/home/models/Qwen2.5-14B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_chat_template=True)
logger = init_logger(__name__)
model = ""
path_to_dataset = ""
data_dir = ""
tokenizer = None


def setup_environment_variables():
os.environ["VLLM_USE_V1"] = "1"
os.environ["PYTHONHASHSEED"] = "123456"

global model, path_to_dataset, data_dir, tokenizer
model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
if not os.path.isdir(model):
model = input("Enter path to model, e.g. /home/models/Qwen2.5-14B-Instruct: ")
if not os.path.isdir(model):
print("Exiting. Incorrect model_path")
sys.exit(1)

path_to_dataset = os.getenv(
"DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl"
)
if not os.path.isfile(path_to_dataset):
path_to_dataset = input(
"Enter path to one of the longbench dataset, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: "
)
if not os.path.isfile(path_to_dataset):
print("Exiting. Incorrect dataset path")
sys.exit(1)

data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
data_dir = input(
"Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
)
if not os.path.isdir(data_dir):
create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ")
if create.lower() == "y":
os.makedirs(data_dir, exist_ok=True)
else:
print("Exiting. Directory not created.")
sys.exit(1)

tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)


@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
@@ -32,7 +68,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
kv_connector_extra_config={
"ucm_connector_name": "UcmNfsStore",
"ucm_connector_config": {
"storage_backends": "/home/data",
"storage_backends": data_dir,
"kv_block_size": 33554432,
},
"ucm_sparse_config": {
@@ -51,12 +87,12 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
model=model,
kv_transfer_config=ktc,
max_model_len=32768,
gpu_memory_utilization=0.8,
gpu_memory_utilization=0.6,
max_num_batched_tokens=30000,
block_size=128,
enforce_eager=True,
distributed_executor_backend="mp",
tensor_parallel_size=2,
tensor_parallel_size=1,
)

llm = LLM(**asdict(llm_args))
@@ -85,8 +121,6 @@ def print_output(
def main():
module_path = "ucm.integration.vllm.uc_connector"
name = "UnifiedCacheConnectorV1"
model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")

setup_environment_variables()

def get_prompt(prompt):
@@ -106,10 +140,11 @@ def get_prompt(prompt):

with build_llm_with_uc(module_path, name, model) as llm:
prompts = []

batch_size = 1

with open("/home/Longbench/data/multifieldqa_zh.jsonl", "r") as f:
batch_size = 5
assert os.path.isfile(
path_to_dataset
), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`"
with open(path_to_dataset, "r") as f:
for _ in range(batch_size):
line = f.readline()
if not line:
1 change: 1 addition & 0 deletions setup.py
@@ -71,6 +71,7 @@ def build_cmake(self, ext: CMakeExtension):
"cmake",
"-DCMAKE_BUILD_TYPE=Release",
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DPYTHON3_EXECUTABLE={sys.executable}",
]

if _is_cuda():
90 changes: 38 additions & 52 deletions ucm/ucm_sparse/esa.py
@@ -444,6 +444,26 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
backend = retrieval_backend.RetrievalWorkerBackend(backend_src)
self.retrieval_workers.append(RetrievalWorker(backend))

def create_layerwise_req_state(self, req_meta, layer_name):
layer_id = int(layer_name.split(".")[2])
if req_meta.request_id not in self.req_states:
if self.req_states.get(req_meta.request_id) is None:
self.req_states[req_meta.request_id] = [
None
] * self.total_num_hidden_layers
if self.req_states[req_meta.request_id][layer_id] is None:
self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayer(
req_meta,
layer_name,
self.rank,
self.tp_size,
self.connector,
self._vllm_config,
self.retrieval_workers[layer_id],
self.layer_pools[layer_id],
)
return self.req_states[req_meta.request_id][layer_id]

def attention_begin(
self,
query: torch.Tensor,
@@ -453,24 +473,7 @@ def attention_begin(
forward_context: ForwardContext,
) -> None:
for req_meta in self._sparse_metadata.requests:
layer_id = int(layer_name.split(".")[2])
if req_meta.request_id not in self.req_states:
if self.req_states.get(req_meta.request_id) is None:
self.req_states[req_meta.request_id] = [
None
] * self.total_num_hidden_layers
if self.req_states[req_meta.request_id][layer_id] is None:
self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayer(
req_meta,
layer_name,
self.rank,
self.tp_size,
self.connector,
self._vllm_config,
self.retrieval_workers[layer_id],
self.layer_pools[layer_id],
)
req_state = self.req_states[req_meta.request_id][layer_id]
req_state = self.create_layerwise_req_state(req_meta, layer_name)
req_state.update_meta(req_meta)
req_state.attention_begin(query, key, value, forward_context)

@@ -484,29 +487,18 @@ def attention_finished(
forward_context: ForwardContext,
) -> None:
for req_meta in self._sparse_metadata.requests:
layer_id = int(layer_name.split(".")[2])
if req_meta.request_id not in self.req_states:
if self.req_states.get(req_meta.request_id) is None:
self.req_states[req_meta.request_id] = [
None
] * self.total_num_hidden_layers
if self.req_states[req_meta.request_id][layer_id] is None:
self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayer(
req_meta,
layer_name,
self.rank,
self.tp_size,
self.connector,
self._vllm_config,
self.retrieval_workers[layer_id],
self.layer_pools[layer_id],
)
req_state = self.req_states[req_meta.request_id][layer_id]
req_state = self.create_layerwise_req_state(req_meta, layer_name)
req_state.update_meta(req_meta)
req_state.attention_finished(
query, key, value, attn_output, forward_context
)

def is_sparsed_request(self, req):
return (
len(req.prompt_token_ids)
>= self._vllm_config.cache_config.block_size * self.esa_cfg["min_blocks"]
)

def build_sparse_meta(
self, scheduler_output, requests, input_batch, attn_metadata
) -> UcmSparseMetadata:
@@ -515,41 +507,35 @@
req_id,
num_scheduled_tokens,
) in scheduler_output.num_scheduled_tokens.items():
req_state = requests[req_id]
if (
len(req_state.prompt_token_ids)
<= self._vllm_config.cache_config.block_size
):
return

req = requests[req_id]
if not self.is_sparsed_request(req):
continue
if isinstance(attn_metadata, dict):
attn_metadata = next(iter(attn_metadata.values()))
sparse_meta.add_request(
req_id,
input_batch.req_id_to_index[req_id],
num_scheduled_tokens,
req_state.num_computed_tokens,
req_state.block_ids[0],
req.num_computed_tokens,
req.block_ids[0],
attn_metadata.query_start_loc[input_batch.req_id_to_index[req_id]],
req_state.prompt_token_ids,
req_state.output_token_ids,
req.prompt_token_ids,
req.output_token_ids,
)
self._sparse_metadata = sparse_meta

def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]):
pass

def request_finished_in_worker(self, request_id: ReqType):
if request_id not in self.req_states:
return
for layer_state in self.req_states[request_id]:
layer_state.repre_pool.free(layer_state.slots)
del self.req_states[request_id]

def estimate_num_slots_sparsed(self, request: Request) -> int:
if (
request.num_output_tokens == 0
or request.num_prompt_tokens
< self._vllm_config.cache_config.block_size * self.esa_cfg["min_blocks"]
):
if request.num_output_tokens == 0 or not self.is_sparsed_request(request):
return INVALID_SLOT
prompt_len = request.num_prompt_tokens
output_len = request.num_output_tokens
Expand Down