Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class RequestMeta:
hbm_hit_block_num: int = 0
# local_computed_block + external_computed_block
total_hit_block_num: int = 0
num_token_ids: int = 0
vllm_block_ids: list[int] = field(default_factory=list)
token_processed: int = 0


@dataclass
Expand Down Expand Up @@ -207,7 +210,6 @@ def get_num_new_matched_tokens(
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:

assert num_computed_tokens % self.block_size == 0
hbm_hit_block_num = num_computed_tokens // self.block_size

Expand Down Expand Up @@ -242,13 +244,16 @@ def get_num_new_matched_tokens(
# When all the tokens are cached in SSD or HBM,
# we need to recompute the last token. This `if` condition will be removed
# once the vLLM scheduler provides a better solution.
if total_hit_block_num * self.block_size == request.num_tokens:
num_total_hit_tokens = total_hit_block_num * self.block_size
if num_total_hit_tokens == request.num_tokens:
external_hit_tokens -= 1

self.requests_meta[request.request_id] = RequestMeta(
ucm_block_ids=ucm_block_ids,
hbm_hit_block_num=hbm_hit_block_num,
total_hit_block_num=total_hit_block_num,
num_token_ids=len(request.all_token_ids),
token_processed=num_total_hit_tokens,
)

return external_hit_tokens, False
Expand Down Expand Up @@ -277,22 +282,16 @@ def _generate_dispatch_meta(
| scheduled_block_num |
"""

new_blocks_num = new_tokens // self.block_size
hbm_hit_block_num = req_meta.hbm_hit_block_num
total_hit_block_num = req_meta.total_hit_block_num
scheduled_block_num = total_hit_block_num + new_blocks_num
ucm_block_ids = req_meta.ucm_block_ids
req_meta.vllm_block_ids.extend(vllm_block_ids)

dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num]
if need_load:
dump_vllm_block_ids = vllm_block_ids[
total_hit_block_num:scheduled_block_num
]
else:
dump_vllm_block_ids = vllm_block_ids

# after this round, req_meta will be updated
req_meta.total_hit_block_num = scheduled_block_num
start_idx = req_meta.token_processed // self.block_size
end_idx = (req_meta.token_processed + new_tokens) // self.block_size
dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx]
dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx]
req_meta.token_processed += new_tokens

load_ucm_block_ids, load_vllm_block_ids = [], []
if need_load:
Expand Down Expand Up @@ -332,10 +331,13 @@ def build_connector_meta(
continue
req_meta = self.requests_meta.get(request_id)
if req_meta:
new_block_ids = []
if scheduled_cached_reqs.new_block_ids[i] != None:
new_block_ids = scheduled_cached_reqs.new_block_ids[i][0]
requests_dispatch_meta[request_id] = self._generate_dispatch_meta(
req_meta,
scheduler_output.num_scheduled_tokens[request_id],
scheduled_cached_reqs.new_block_ids[i][0],
new_block_ids,
scheduled_cached_reqs.resumed_from_preemption[i],
)
else:
Expand Down