4 changes: 4 additions & 0 deletions fastdeploy/engine/request.py
@@ -40,6 +40,7 @@ class RequestType(Enum):
PREFILL = 0
DECODE = 1
PREEMPTED = 2
EXTEND = 3


@dataclass
@@ -141,6 +142,9 @@ def __init__(
self.task_type = RequestType.PREFILL
self.idx = None
self.need_prefill_tokens = self.prompt_token_ids_len
# extend block tables
self.use_extend_tables = False
self.extend_block_tables = []

@classmethod
def from_dict(cls, d: dict):
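The new EXTEND member and the two Request fields give downstream consumers a way to recognize and carry extend block tables. A minimal sketch of branching on the new enum value; handle_task is a hypothetical stand-in, not a FastDeploy API:

from enum import Enum

class RequestType(Enum):
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2
    EXTEND = 3  # added by this PR

def handle_task(task_type: RequestType, block_tables: list) -> str:
    # A worker loop could dispatch on the new value like this.
    if task_type is RequestType.EXTEND:
        return f"switch to extend tables: {block_tables}"
    return f"regular {task_type.name.lower()} step"

print(handle_task(RequestType.EXTEND, [0, 1, 2, 7, 8]))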
74 changes: 74 additions & 0 deletions fastdeploy/engine/sched/resource_manager_v1.py
@@ -55,6 +55,18 @@ class ScheduledPreemptTask:
task_type: RequestType = RequestType.PREEMPTED


@dataclass
class ScheduledExtendBlocksTask:
"""
Task for allocating additional blocks that extend a request's block tables.
"""

idx: int
request_id: str
extend_block_tables: list[int]
task_type: RequestType = RequestType.EXTEND
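Since task_type defaults to RequestType.EXTEND, call sites only pass the three data fields. A self-contained construction sketch; the RequestType stub stands in for the real import:

from dataclasses import dataclass
from enum import Enum

class RequestType(Enum):  # stub for fastdeploy.engine.request.RequestType
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2
    EXTEND = 3

@dataclass
class ScheduledExtendBlocksTask:
    idx: int
    request_id: str
    extend_block_tables: list
    task_type: RequestType = RequestType.EXTEND

task = ScheduledExtendBlocksTask(idx=3, request_id="req-42", extend_block_tables=[0, 1, 9, 10])
assert task.task_type is RequestType.EXTEND  # default applies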


class ResourceManagerV1(ResourceManager):
"""
Resource manager for scheduler v1.
@@ -80,6 +92,8 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
self.to_be_rescheduled_request_id_set = set()
main_process_metrics.max_batch_size.set(max_num_seqs)

self.using_extend_tables_req_id = set()

def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
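allocated_slots is plain arithmetic: the number of blocks a request holds times tokens per block. A worked example, assuming an illustrative block_size of 64 (the real value comes from cache_config.block_size):

block_size = 64                 # tokens per block (illustrative)
block_tables = [12, 13, 14]     # request currently holds three GPU blocks
allocated_slots = len(block_tables) * block_size
print(allocated_slots)          # 192 token slots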

@@ -405,6 +419,57 @@ def schedule(self):
break
else:
llm_logger.error("Unknown request status type")

# schedule extend-block-table transitions (enable / grow / disable) for running requests
for req in self.running:
num_prefill_blocks = req.need_prefill_tokens // self.config.cache_config.block_size
# allocate
if req.use_extend_tables and req.request_id not in self.using_extend_tables_req_id:
llm_logger.info(
f"req {req.request_id} at batch id {req.idx} with num_prefill_blocks {num_prefill_blocks} is going to enable extend tables"
)
self.using_extend_tables_req_id.add(req.request_id)
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
req.extend_block_tables = req.block_tables[:num_prefill_blocks] # copy prompt cache
req.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
)
scheduled_reqs.append(
ScheduledExtendBlocksTask(
idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
)
)
llm_logger.info(f"extend blocks is {req.extend_block_tables}")
else:
continue
# recycle
elif not req.use_extend_tables and req.request_id in self.using_extend_tables_req_id:
llm_logger.info(f"req {req.request_id} is going to disable extend tables")
self.using_extend_tables_req_id.remove(req.request_id)
self.cache_manager.recycle_gpu_blocks(req.extend_block_tables[num_prefill_blocks:])
req.extend_block_tables = []

# grow: allocate more extend blocks when the allocated slots are nearly exhausted
elif req.request_id in self.using_extend_tables_req_id:
if (
self.allocated_slots(req) - req.num_total_tokens
<= self.config.cache_config.prealloc_dec_block_slot_num_threshold
):
llm_logger.info(
f"req {req.request_id} is going to alocate more extend tables because allocated_slots {self.allocated_slots(req)} and prealloc_dec_block_slot_num_threshold {self.config.cache_config.prealloc_dec_block_slot_num_threshold} req.num_total_tokens {req.num_total_tokens}"
)
if self.cache_manager.can_allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num):
req.extend_block_tables.extend(
self.cache_manager.allocate_gpu_blocks(self.config.cache_config.enc_dec_block_num)
)
scheduled_reqs.append(
ScheduledExtendBlocksTask(
idx=req.idx, request_id=req.request_id, extend_block_tables=req.extend_block_tables
)
)
else:
continue
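The loop above is effectively a per-request state machine with three transitions: enable (copy the prefill prefix of block_tables, then append enc_dec_block_num fresh blocks), disable (recycle everything past that prefix), and grow (append more blocks once free slots run low). Note that num_prefill_blocks uses floor division, so a partially filled last prefill block is not part of the shared prefix. A condensed, self-contained sketch of the enable and disable transitions; FakeCacheManager and all numbers are illustrative stand-ins, not FastDeploy APIs:

class FakeCacheManager:
    """Stand-in cache manager that hands out block ids from a free list."""

    def __init__(self, num_free: int):
        self.free = list(range(100, 100 + num_free))

    def can_allocate_gpu_blocks(self, n: int) -> bool:
        return len(self.free) >= n

    def allocate_gpu_blocks(self, n: int) -> list:
        out, self.free = self.free[:n], self.free[n:]
        return out

    def recycle_gpu_blocks(self, blocks: list) -> None:
        self.free.extend(blocks)

cm = FakeCacheManager(num_free=8)
block_tables = [0, 1, 2, 3]   # blocks the request already holds
num_prefill_blocks = 2        # need_prefill_tokens // block_size
enc_dec_block_num = 2

# enable: share the prompt-cache prefix, then append freshly allocated blocks
extend = block_tables[:num_prefill_blocks]  # [0, 1] -- new list, same block ids
if cm.can_allocate_gpu_blocks(enc_dec_block_num):
    extend.extend(cm.allocate_gpu_blocks(enc_dec_block_num))  # [0, 1, 100, 101]

# disable: recycle only the appended blocks; the prefix still backs block_tables
cm.recycle_gpu_blocks(extend[num_prefill_blocks:])  # returns [100, 101] to the pool
extend = []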

if scheduled_reqs:
task_used_block_num = sum([len(task.block_tables) if task else 0 for task in self.tasks_list])
main_process_metrics.available_gpu_block_num.set(self.total_block_number() - task_used_block_num)
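The grow branch fires once a request's free slots shrink to the preallocation threshold, i.e. when allocated_slots(req) - req.num_total_tokens <= prealloc_dec_block_slot_num_threshold. Worked with illustrative numbers, not values from the PR:

block_size = 64
allocated_slots = 4 * block_size   # 4 blocks held -> 256 token slots
num_total_tokens = 240             # prompt + generated tokens so far
prealloc_dec_block_slot_num_threshold = 32
needs_more = allocated_slots - num_total_tokens <= prealloc_dec_block_slot_num_threshold
print(needs_more)                  # 256 - 240 = 16 <= 32 -> True, so more blocks are appended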
@@ -488,6 +553,15 @@ def _free_blocks(self, request: Request):
self.cache_manager.recycle_gpu_blocks(request.block_tables)
request.block_tables = []

if request.request_id in self.using_extend_tables_req_id:
num_prefill_blocks = request.need_prefill_tokens // self.config.cache_config.block_size
self.using_extend_tables_req_id.remove(request.request_id)
self.cache_manager.recycle_gpu_blocks(request.extend_block_tables[num_prefill_blocks:])
llm_logger.info(
f"req {request.request_id} recycle extend blocks {request.extend_block_tables[num_prefill_blocks:]}"
)
request.extend_block_tables = []
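Only the slice past the prefill prefix is recycled here: the first num_prefill_blocks entries of extend_block_tables carry the same block ids as request.block_tables, which the lines above already recycled, so returning them again would double-free. A small sketch with illustrative block ids:

block_tables = [0, 1, 2, 3]             # recycled in full a few lines up
extend_block_tables = [0, 1, 100, 101]  # prefix [0, 1] shares ids with block_tables
num_prefill_blocks = 2
to_recycle = extend_block_tables[num_prefill_blocks:]  # [100, 101], the extend-only blocks
print(to_recycle)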

def finish_requests_async(self, request_ids: Union[str, Iterable[str]]):
return self.finish_execution_pool.submit(self.finish_requests, request_ids)
