11 changes: 5 additions & 6 deletions examples/ucm_config_example.yaml
@@ -8,12 +8,11 @@
# for backward compatibility.

# Connector name (e.g., "UcmNfsStore", "UcmDramStore")
ucm_connector_name: "UcmNfsStore"

# Connector-specific configuration
ucm_connector_config:
storage_backends: "/mnt/test"
transferIoDirect: false
ucm_connectors:
- ucm_connector_name: "UcmNfsStore"
ucm_connector_config:
storage_backends: "/mnt/test"
use_direct: false

load_only_first_rank: false

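Note on the new layout: the flat `ucm_connector_name`/`ucm_connector_config` keys become a `ucm_connectors` list, and, as the connector changes below show, only the first entry is consumed for now. A minimal sketch of reading the new format, assuming the file is loaded with PyYAML (the loading path is an assumption, not this repo's code):

```python
# Sketch only: load the example config and pick the first connector entry,
# mirroring how UCMDirectConnector reads `ucm_connectors` in the diff below.
import yaml  # assumption: PyYAML is available

with open("examples/ucm_config_example.yaml") as f:
    launch_config = yaml.safe_load(f)

connector_configs = launch_config.get("ucm_connectors", [])
assert connector_configs, "no storage connector name in config."

name = connector_configs[0]["ucm_connector_name"]                 # "UcmNfsStore"
config = connector_configs[0].get("ucm_connector_config") or {}
print(name, config.get("storage_backends"), config.get("use_direct", False))
```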
142 changes: 77 additions & 65 deletions ucm/integration/vllm/ucm_connector.py
@@ -56,37 +56,21 @@ class RequestHasher:

_SEED_HASH = None

def __init__(self):
if RequestHasher._SEED_HASH is None:
RequestHasher._SEED_HASH = self._md5("UCM_HASH_SEED")

@staticmethod
def _md5(input_data) -> int:
input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)
md5_bytes = hashlib.md5(input_bytes).digest()
return int.from_bytes(md5_bytes, byteorder="big")

def __call__(self, block_size: int, request: "Request") -> list[str]:
token_ids = request.all_token_ids

ret = []
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
def __init__(self, vllm_config, rank_id):
meta = f"{vllm_config.model_config.model}:{vllm_config.parallel_config.world_size}:{vllm_config.model_config.dtype}:{rank_id}"
self.meta_bytes = meta.encode("utf-8")

if not parent_block_hash_value:
parent_block_hash_value = RequestHasher._SEED_HASH
if RequestHasher._SEED_HASH is None:
RequestHasher._SEED_HASH = self("UCM_HASH_SEED")

block_token_ids_tuple = tuple(block_token_ids)
hash_value = self._md5((parent_block_hash_value, block_token_ids_tuple))
parent_block_hash_value = hash_value
ret.append(str(hash_value))
def __call__(self, input_data) -> int:
if isinstance(input_data, str):
input_bytes = input_data.encode("utf-8")
else:
input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL)

return ret
h = hashlib.md5(self.meta_bytes + input_bytes)
return int.from_bytes(h.digest(), byteorder="big")


class UCMDirectConnector(KVConnectorBase_V1):
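The refactored RequestHasher above no longer walks the request's token blocks itself; it is now a plain callable that salts every MD5 digest with per-engine metadata (model, world size, dtype, rank). A self-contained sketch of that scheme, with a hypothetical metadata string, showing that identical payloads hash differently across ranks:

```python
# Standalone sketch of the metadata-salted hash (not the project's code).
import hashlib
import pickle

def ucm_hash(meta: str, payload) -> int:
    meta_bytes = meta.encode("utf-8")
    if isinstance(payload, str):
        payload_bytes = payload.encode("utf-8")
    else:
        payload_bytes = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(hashlib.md5(meta_bytes + payload_bytes).digest(), byteorder="big")

# Hypothetical metadata layout: "<model>:<world_size>:<dtype>:<rank>"
meta_rank0 = "my-model:2:torch.bfloat16:0"
meta_rank1 = "my-model:2:torch.bfloat16:1"

seed = ucm_hash(meta_rank0, "UCM_HASH_SEED")
block = tuple(range(16))  # a full block of token ids
print(ucm_hash(meta_rank0, (seed, block)) != ucm_hash(meta_rank1, (seed, block)))  # True
```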
@@ -114,15 +98,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
torch_dev = torch.npu
dev_name = "npu"
else:
raise RuntimeError("Unsupported device platform for LMCache engine.")
raise RuntimeError("Unsupported device platform for UCMDirectConnector.")

if self.rank >= 0:
self.device = torch_dev.device(f"{dev_name}:{self.rank}")
self._layer_offset_cache = {}

self.store: UcmKVStoreBase

self.request_hasher = RequestHasher()
if role == KVConnectorRole.SCHEDULER:
self.request_hasher = RequestHasher(vllm_config, 0)
else:
self.request_hasher = RequestHasher(vllm_config, self.rank)

# save block info, avoid hash request twice, and track them until request finished
self.requests_meta: dict[str, RequestMeta] = {}
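Why the scheduler's hasher is pinned to rank 0 while workers use their own rank: the scheduler hands out rank-0 block ids, and the start_load_kv/wait_for_save hunks further down re-hash those ids on every non-zero rank to obtain rank-local storage keys. A minimal, self-contained sketch of that derivation; the metadata string and block id are hypothetical:

```python
# Sketch: derive a rank-local storage key from a scheduler-produced (rank 0)
# block id, as the self.rank != 0 branches below do via self.request_hasher.
import hashlib

def rank_local_id(meta: str, rank0_block_id: str) -> str:
    digest = hashlib.md5(meta.encode("utf-8") + rank0_block_id.encode("utf-8")).digest()
    return str(int.from_bytes(digest, byteorder="big"))

meta_rank2 = "my-model:4:torch.bfloat16:2"   # hypothetical metadata for worker rank 2
scheduler_block_id = "123456789"             # hypothetical rank-0 block id (decimal string)
print(rank_local_id(meta_rank2, scheduler_block_id))
```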
@@ -139,41 +126,60 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.broadcast_fn = self.group_coordinator.broadcast
self.broadcast_stream = torch.cuda.Stream()

if "ucm_connector_name" in self.launch_config:
name = self.launch_config.get("ucm_connector_name")
config = self.launch_config.get("ucm_connector_config") or {}
config["device"] = self.rank
config["role"] = (
"scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
)
element_size = vllm_config.model_config.dtype.itemsize
single_head_dim = vllm_config.model_config.get_head_size()
num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
)
total_tp_size = vllm_config.parallel_config.tensor_parallel_size
num_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
block_size_per_layer = self.block_size * element_size * single_head_dim
config["kv_block_size"] = (
block_size_per_layer
* num_layers
* (1 if self.is_mla else num_head_per_tp * total_tp_size * 2)
)
config["io_size"] = block_size_per_layer * (
1 if self.is_mla else num_head_per_tp
)
self.store = UcmConnectorFactory.create_connector(name, config)
connector_configs = self.launch_config.get("ucm_connectors", [])
assert len(connector_configs) > 0, "no storage connector name in config."

name = connector_configs[0].get("ucm_connector_name")
config = connector_configs[0].get("ucm_connector_config") or {}
config["device"] = self.rank
config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
element_size = vllm_config.model_config.dtype.itemsize
single_head_dim = vllm_config.model_config.get_head_size()
num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
)
total_tp_size = vllm_config.parallel_config.tensor_parallel_size
num_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
block_size_per_layer = self.block_size * element_size * single_head_dim
config["kv_block_size"] = (
block_size_per_layer
* num_layers
* (1 if self.is_mla else num_head_per_tp * 2)
)
config["io_size"] = block_size_per_layer * (
1 if self.is_mla else num_head_per_tp
)
self.store = UcmConnectorFactory.create_connector(name, config)

logger.info("init UCConnectorImpl, connector: %s", name)
logger.info(
"single file size = %d MB, io_size = %d KB,",
config["kv_block_size"] / 1024 / 1024,
config["io_size"] / 1024,
)

def generate_hash(self, block_size: int, request: "Request") -> list[str]:
token_ids = request.all_token_ids

ret = []
parent_block_hash_value = RequestHasher._SEED_HASH
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break

logger.info("init UCConnectorImpl, connector: %s", name)
logger.info(
"single file size = %d MB, io_size = %d KB,",
config["kv_block_size"] / 1024 / 1024,
config["io_size"] / 1024,
block_token_ids_tuple = tuple(block_token_ids)
hash_value = self.request_hasher(
(parent_block_hash_value, block_token_ids_tuple)
)
else:
raise TypeError(f"no storage connector name in config.")
parent_block_hash_value = hash_value
ret.append(str(hash_value))

return ret

def get_num_new_matched_tokens(
self,
@@ -184,7 +190,7 @@ def get_num_new_matched_tokens(
assert num_computed_tokens % self.block_size == 0
hbm_hit_block_num = num_computed_tokens // self.block_size

ucm_block_ids = self.request_hasher(self.block_size, request)
ucm_block_ids = self.generate_hash(self.block_size, request)

external_block_ids = ucm_block_ids[hbm_hit_block_num:]
if not external_block_ids:
@@ -210,7 +216,7 @@ def get_num_new_matched_tokens(
# When all the tokens are cached in ssd or hbm,
# we need to recompute the last token. This if condition will be removed
# once vLLM scheduler provides a better solution in the future.
if external_hit_tokens == request.num_prompt_tokens:
if total_hit_block_num * self.block_size == request.num_tokens:
external_hit_tokens -= 1

self.requests_meta[request.request_id] = RequestMeta(
@@ -449,6 +455,9 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
continue

ucm_block_ids, vllm_block_ids = request.load_block_ids
if self.rank != 0 and not self.is_mla:
for i, ucm_block_id in enumerate(ucm_block_ids):
ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task(
vllm_block_ids, ucm_block_ids
)
@@ -495,6 +504,9 @@ def wait_for_save(self) -> None:
continue

ucm_block_ids, vllm_block_ids = request.dump_block_ids
if self.rank != 0:
for i, ucm_block_id in enumerate(ucm_block_ids):
ucm_block_ids[i] = str(self.request_hasher(ucm_block_id))
rets = self.store.create(ucm_block_ids)
end = 0
for i, ret in enumerate(rets):
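To make the sizing math in this file concrete: `kv_block_size` is the bytes one full KV block occupies across all layers on a single rank (K and V, non-MLA), and `io_size` is the per-layer slice moved in one transfer. A worked sketch with hypothetical model dimensions (none of these numbers come from the PR):

```python
# Worked example of the kv_block_size / io_size formulas above,
# using hypothetical model dimensions.
block_size = 128          # tokens per KV block
element_size = 2          # bytes per element (bfloat16)
single_head_dim = 128
num_head_per_tp = 8       # KV heads handled by this tensor-parallel rank
num_layers = 32
is_mla = False

block_size_per_layer = block_size * element_size * single_head_dim          # 32 KiB
kv_block_size = block_size_per_layer * num_layers * (1 if is_mla else num_head_per_tp * 2)
io_size = block_size_per_layer * (1 if is_mla else num_head_per_tp)

print(kv_block_size // (1024 * 1024), "MB per block")   # 16 MB
print(io_size // 1024, "KB per IO")                      # 256 KB
```

Compared with the previous formula, the `total_tp_size` factor is gone, so each rank now sizes blocks only for its own KV heads, which appears to match the per-rank block ids introduced above.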
2 changes: 1 addition & 1 deletion ucm/store/nfsstore/nfsstore_connector.py
@@ -51,7 +51,7 @@ def __init__(self, config: Dict):
if transfer_enable:
param.transferDeviceId = config["device"]
param.transferIoSize = config["io_size"]
param.transferIoDirect = config.get("transferIoDirect", False)
param.transferIoDirect = config.get("use_direct", False)

# NOTE: compatible with legacy nfsstore lib
if hasattr(param, "storageCapacity"):
Expand Down