96 changes: 96 additions & 0 deletions docs/source/getting-started/example/dram_conn.md
@@ -1,3 +1,99 @@
# DRAM Connector

This document provides a usage example and configuration guide for the **DRAM Connector**. This connector enables offloading of KV cache from GPU HBM to CPU DRAM, helping reduce memory pressure and support larger models or batch sizes.

## Features

The DRAM connector supports the following functionalities:

- `dump`: Offload KV cache blocks from HBM to DRAM.
- `load`: Load KV cache blocks from DRAM back to HBM.
- `lookup`: Look up KV blocks stored in DRAM by block hash.
- `wait`: Ensure that all copy streams between CPU and GPU have completed.
- `commit`: Mark cache operations as complete and ready for reuse.
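
The sketch below, based on the `UcmDram` API exercised by the unit tests in `test/test_ucm_dram.py`, shows how these operations fit together; block hashes, offsets, tensor shapes, and config values are placeholders:

```python
import torch

from unifiedcache.ucm_connector.ucm_dram import UcmDram

# Worker-side connector; the config values here are illustrative only.
worker = UcmDram({"role": "worker", "max_cache_size": 1 << 30, "kv_block_size": 262144})

block_hashes = ["hash_a", "hash_b"]   # normally derived from the request's token blocks
offsets = [0, 1]
src = [torch.randint(0, 100, (4,), dtype=torch.int8) for _ in block_hashes]

task = worker.dump(block_hashes, offsets, src)   # dump: HBM -> DRAM, returns a DramTask
worker.wait(task)                                # wait: block until the copy stream completes

dst = [torch.zeros(4, dtype=torch.int8) for _ in block_hashes]
task = worker.load(block_hashes, offsets, dst)   # load: DRAM -> HBM
worker.wait(task)

# lookup runs on the scheduler-side instance and returns a per-block hit list.
scheduler = UcmDram({"role": "scheduler", "max_cache_size": 1 << 30, "kv_block_size": 262144})
hits = scheduler.lookup(block_hashes)            # e.g. [False, False] until blocks are registered
```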

## Configuration

To use the DRAM connector, configure the `ucm_connector_config` dictionary inside `kv_connector_extra_config` in your model's launch configuration.

### Parameters

- `max_cache_size` *(optional)*:
  Specifies the maximum DRAM memory usage allowed for caching, in **bytes**, set in `kv_connector_extra_config["ucm_connector_config"]`.
  If not provided, it defaults to **5 GB**.

### Example

```python
kv_connector_extra_config={"ucm_connector_name": "UcmDram", "ucm_connector_config": {"max_cache_size": 5368709120}}
# Allocate up to 5 GB of DRAM (5368709120 bytes) for the KV cache
```

## Launching Inference

### Offline Inference

To start **offline inference** with the DRAM connector, modify the script `examples/vllm_kv_offload.py` so that its `kv_connector_extra_config` selects the DRAM connector:

```python
# In examples/vllm_kv_offload.py
ktc = KVTransferConfig(
...
kv_connector_extra_config={"ucm_connector_name": "UcmDram", "ucm_connector_config":{"max_cache_size": 5368709120}}
)
```
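
For orientation, here is a self-contained sketch of what the modified configuration might look like; the remaining `KVTransferConfig` fields mirror the online-serving command in the next section, the import paths follow current vLLM conventions, and the model path is a placeholder:

```python
# Sketch only -- not a verbatim copy of examples/vllm_kv_offload.py.
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

ktc = KVTransferConfig(
    kv_connector="UnifiedCacheConnectorV1",
    kv_connector_module_path="unifiedcache.integration.vllm.uc_connector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "ucm_connector_name": "UcmDram",
        "ucm_connector_config": {"max_cache_size": 5368709120},  # 5 GB of DRAM for the KV cache
    },
)

llm = LLM(model="/home/models/Qwen2.5-14B-Instruct",  # placeholder model path
          kv_transfer_config=ktc)
outputs = llm.generate(["Shanghai is a"], SamplingParams(max_tokens=7, temperature=0))
print(outputs[0].outputs[0].text)
```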

Then run the script as follows:

```bash
cd examples/
python vllm_kv_offload.py
```

### Online Inference

For **online inference**, vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol. Run the following command to start the vLLM server with the Qwen2.5-14B-Instruct model:

```bash
vllm serve /home/models/Qwen2.5-14B-Instruct \
--max-model-len 20000 \
--tensor-parallel-size 2 \
--gpu-memory-utilization 0.87 \
--trust-remote-code \
--port 7800 \
--kv-transfer-config \
'{
"kv_connector": "UnifiedCacheConnectorV1",
"kv_connector_module_path": "unifiedcache.integration.vllm.uc_connector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"ucm_connector_name": "UcmDram",
"ucm_connector_config": {
"max_cache_size": 5368709120
}
}
}'
```

If you see logs like the following:

```bash
INFO: Started server process [32890]
INFO: Waiting for application startup.
INFO: Application startup complete.
```

Congratulations, you have successfully started the vLLM server with the DRAM connector!

Once the server is up, you can interact with the API as follows:

```bash
curl http://localhost:7800/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/home/models/Qwen2.5-14B-Instruct",
"prompt": "Shanghai is a",
"max_tokens": 7,
"temperature": 0
}'
```
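
The same request can also be issued through the OpenAI Python client (assuming `pip install openai`); any non-empty API key is accepted by the local server:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:7800/v1", api_key="EMPTY")
completion = client.completions.create(
    model="/home/models/Qwen2.5-14B-Instruct",
    prompt="Shanghai is a",
    max_tokens=7,
    temperature=0,
)
print(completion.choices[0].text)
```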
164 changes: 164 additions & 0 deletions test/test_ucm_dram.py
@@ -0,0 +1,164 @@
import random
import unittest
from typing import List
from unittest.mock import MagicMock

import torch
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import hash_request_tokens
from vllm.v1.request import Request

from unifiedcache.ucm_connector.ucm_dram import UcmDram, DramTask


def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None,
cache_salt=None):
if mm_positions is None:
multi_modal_inputs = None
else:
multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)

return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_inputs=multi_modal_inputs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
arrival_time=0,
lora_request=None,
cache_salt=cache_salt,
)


class TestUcmDram(unittest.TestCase):

@classmethod
def setUpClass(cls):
print("===> Before all tests (setUpClass)")

@classmethod
def tearDownClass(cls):
print("===> After all tests (setUpClass)")

def setUp(self):
self.config = {
"block_size": 4
}
self.scheduler_config = {
"role": "scheduler",
"max_cache_size": 1073741824,
"kv_block_size": 262144
}
self.worker_config = {
"role": "worker",
"max_cache_size": 1073741824,
"kv_block_size": 262144
}

self.block_number = 4
self.block_size = int(self.config["block_size"])
self.scheduler_dram = UcmDram(self.scheduler_config)
self.worker_dram = UcmDram(self.worker_config)
random.seed(20250728)
self.request = make_request(
request_id=1,
prompt_token_ids=random.sample(range(0, 10000), self.block_number * self.block_size),
mm_positions=None,
mm_hashes=None,
)
block_hash_types = hash_request_tokens(sha256, self.block_size, self.request)
self.block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types]

def test_look_up_all_hit(self):
'''
        Test that all blocks are hit in the cache
'''
expected = [True] * len(self.block_hashes)
self.scheduler_dram.cached_blocks.update(self.block_hashes)
actual = self.scheduler_dram.lookup(self.block_hashes)

self.assertEqual(actual, expected)

def test_lookup_partial_hit(self):
'''
        Test that part of the blocks are hit in the cache
'''
partial_index = random.randint(0, 4)
partial_hashes = self.block_hashes[:partial_index]
self.scheduler_dram.cached_blocks.update(partial_hashes)
actual = self.scheduler_dram.lookup(self.block_hashes)
        expected = [True] * partial_index + [False] * (len(self.block_hashes) - partial_index)
self.assertEqual(actual, expected)

def test_lookup_none_hit(self):
'''
        Test that none of the blocks are hit in the cache
'''
actual = self.scheduler_dram.lookup(self.block_hashes)
expected = [False] * len(self.block_hashes)
self.assertEqual(actual, expected)

def test_load_success(self):
'''
        Test that data is loaded back from the cache successfully
'''
src_tensors = [torch.randint(0, 100, (self.block_size,), dtype=torch.int8)
for _ in range(len(self.block_hashes))]
offsets = [i for i in range(len(self.block_hashes))]
dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors)
self.worker_dram.wait(dump_task)
dst_tensors = [torch.zeros(self.block_size, dtype=torch.int8)
for _ in range(len(self.block_hashes))]
load_task = self.worker_dram.load(self.block_hashes, offsets, dst_tensors)

self.assertIsInstance(load_task, DramTask)
self.assertIsNotNone(load_task.event)
for i, (src_tensor, dst_tensor) in enumerate(zip(src_tensors, dst_tensors)):
self.assertEqual(dst_tensor.shape[0], self.block_size)
self.assertTrue(torch.equal(src_tensor, dst_tensor),
f"Block {i} loaded data is different")

def test_dump_success(self):
'''
        Test that data is dumped to the cache successfully
'''
src_tensors = [torch.randint(0, 100, (self.block_size,), dtype=torch.int8)
for _ in range(len(self.block_hashes))]
offsets = [i for i in range(len(self.block_hashes))]
original_data = [tensor.clone() for tensor in src_tensors]
dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors)
self.assertIsInstance(dump_task, DramTask)
self.assertIsNotNone(dump_task.event)
self.worker_dram.wait(dump_task)
for i, block_id in enumerate(self.block_hashes):
key = block_id + '_' + str(offsets[i])
cached_data = self.worker_dram.dram_cache[key]
self.assertEqual(cached_data.shape[0], self.block_size)
self.assertTrue(torch.equal(cached_data, original_data[i]))

def test_wait_success(self):
'''
        Test that waiting for a task succeeds
'''
task = DramTask()
task.event = MagicMock()
result = self.worker_dram.wait(task)
self.assertEqual(result, 0)
task.event.synchronize.assert_called_once()

def test_wait_failure(self):
task = DramTask()
task.event = None
result = self.worker_dram.wait(task)
self.assertEqual(result, -1)


if __name__ == '__main__':
unittest.main()
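
Since the module ends with `unittest.main()`, the tests can be run directly or via pytest, assuming vLLM and `unifiedcache` are importable in the environment:

```bash
python test/test_ucm_dram.py
# or
python -m pytest test/test_ucm_dram.py -v
```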
4 changes: 4 additions & 0 deletions unifiedcache/integration/vllm/uc_connector.py
@@ -101,6 +101,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.num_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
self.element_size = vllm_config.model_config.dtype.itemsize
if self._vllm_config.kv_transfer_config is not None and \
"ucm_connector_name" in self._vllm_config.kv_transfer_config.kv_connector_extra_config:
name = self._vllm_config.kv_transfer_config.kv_connector_extra_config["ucm_connector_name"]
@@ -109,6 +110,9 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
config = self._vllm_config.kv_transfer_config.kv_connector_extra_config["ucm_connector_config"]
config["device"] = self.rank
config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
head_size = vllm_config.model_config.get_head_size()
total_num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
config["kv_block_size"] = self.block_size * head_size * total_num_kv_heads * self.element_size
logger.info("init UCConnectorImpl, connector: %s", name)
self.connector = UcmConnectorFactory.create_connector(name, config)
else:
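
The added lines size each KV cache block as `block_size * head_size * total_num_kv_heads * element_size` bytes. A quick illustration with hypothetical values, not tied to any particular model:

```python
# Illustrative only: how kv_block_size is derived from the added config lines above.
block_size = 16            # tokens per KV-cache block (a common vLLM default)
head_size = 128            # assumed head dimension
total_num_kv_heads = 8     # assumed number of KV heads (GQA)
element_size = 2           # bf16/fp16 -> 2 bytes per element

kv_block_size = block_size * head_size * total_num_kv_heads * element_size
print(kv_block_size)       # 32768 bytes with these values
```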
4 changes: 4 additions & 0 deletions unifiedcache/ucm_connector/factory.py
@@ -43,3 +43,7 @@ def create_connector(
"UcmOceanStore",
"unifiedcache.ucm_connector.ucm_oceanstor",
"UcmOceanStore")
UcmConnectorFactory.register_connector(
"UcmDram",
"unifiedcache.ucm_connector.ucm_dram",
"UcmDram")