diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index b6657e90..9e793485 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -37,7 +37,6 @@ KVConnectorRole, ) from vllm.distributed.parallel_state import get_world_group -from vllm.v1.core.kv_cache_utils import hash_request_tokens from vllm.v1.core.sched.output import SchedulerOutput from ucm.logger import init_logger @@ -545,10 +544,32 @@ def md5(input) -> int: input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) md5_bytes = hashlib.md5(input_bytes).digest() return int.from_bytes(md5_bytes, byteorder="big") + + def hash_request_tokens(hash_function: Any, block_size: int, + request: "Request") -> list[str]: + token_ids = request.all_token_ids + + ret = [] + parent_block_hash_value = None + for start in range(0, len(token_ids), block_size): + end = start + block_size + block_token_ids = token_ids[start:end] + # Do not hash the block if it is not full. + if len(block_token_ids) < block_size: + break + + if not parent_block_hash_value: + parent_block_hash_value = md5("UCMHASHSEED") + + block_token_ids_tuple = tuple(block_token_ids) + hash_value = hash_function((parent_block_hash_value, block_token_ids_tuple)) + parent_block_hash_value = hash_value + ret.append(str(hash_value)) + + return ret assert num_computed_tokens % self.block_size == 0 - block_hash_types = hash_request_tokens(md5, self.block_size, request) - block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types] + block_hashes = hash_request_tokens(md5, self.block_size, request) if not block_hashes: logger.debug("Maybe tokens too short to load.") return 0, False