Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions ucm/integration/vllm/uc_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
KVConnectorRole,
)
from vllm.distributed.parallel_state import get_world_group
from vllm.v1.core.kv_cache_utils import hash_request_tokens
from vllm.v1.core.sched.output import SchedulerOutput

from ucm.logger import init_logger
Expand Down Expand Up @@ -545,10 +544,32 @@ def md5(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
md5_bytes = hashlib.md5(input_bytes).digest()
return int.from_bytes(md5_bytes, byteorder="big")

def hash_request_tokens(hash_function: Any, block_size: int,
                        request: "Request") -> list[str]:
    """Compute chained hashes for every full token block of a request.

    The tokens in ``request.all_token_ids`` are split into consecutive
    blocks of ``block_size`` tokens. Each block is hashed together with the
    hash of the block before it, so a block's hash depends on the whole
    prefix. The first block is chained to ``md5("UCMHASHSEED")``.

    Args:
        hash_function: Callable applied to the tuple
            ``(parent_hash, block_token_ids_tuple)``.
        block_size: Number of tokens per block.
        request: Object exposing ``all_token_ids``.

    Returns:
        ``str(hash_value)`` for each full block, in order. A trailing
        partial block is not hashed, so the list is empty when the request
        holds fewer than ``block_size`` tokens.
    """
    token_ids = request.all_token_ids

    ret = []
    parent_block_hash_value = None
    for start in range(0, len(token_ids), block_size):
        block_token_ids = token_ids[start:start + block_size]
        # Do not hash the block if it is not full.
        if len(block_token_ids) < block_size:
            break

        # Compare against None explicitly: a falsy hash value (e.g. 0) is
        # still a valid parent and must not restart the chain at the seed.
        if parent_block_hash_value is None:
            parent_block_hash_value = md5("UCMHASHSEED")

        block_token_ids_tuple = tuple(block_token_ids)
        hash_value = hash_function(
            (parent_block_hash_value, block_token_ids_tuple))
        parent_block_hash_value = hash_value
        ret.append(str(hash_value))

    return ret

assert num_computed_tokens % self.block_size == 0
block_hash_types = hash_request_tokens(md5, self.block_size, request)
block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types]
block_hashes = hash_request_tokens(md5, self.block_size, request)
if not block_hashes:
logger.debug("Maybe tokens too short to load.")
return 0, False
Expand Down