lmdeploy/pytorch/engine/engine.py (18 additions, 2 deletions)
@@ -896,6 +896,23 @@ def __need_logits(seqs: SeqList):
             """Need logits."""
             return any(seq.return_logits for seq in seqs)

+        def __need_schedule_again(prefill: bool, scheduler_output):
+            """Need schedule again."""
+            # only reschedule when prefill
+            if not prefill:
+                return False
+            # schedule decoding if no valid prefill reqs.
+            if len(scheduler_output.running) > 0:
+                return False
+            # disable decoding for prefill role
+            if (self.engine_config.role == EngineRole.Prefill):
+                return False
+            # disable decoding if no running reqs.
+            if not self.scheduler.has_running():
+                logger.warning('No running sequences for decoding scheduling after prefill scheduling.')
+                return False
+            return True
+
         scheduler = self.scheduler
         logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}')
@@ -905,8 +922,7 @@ def __need_logits(seqs: SeqList):
         if enable_empty and len(scheduler_output.running) == 0:
             return None

-        # schedule decoding if no valid prefill reqs.
-        if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill:
+        if __need_schedule_again(prefill, scheduler_output):
             prefill = False
             prealloc_size = self.engine_strategy.get_prealloc_size(not prefill)
             scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size)
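For context, the new __need_schedule_again helper consolidates the fallback that previously lived inline at the call site: when a prefill pass schedules nothing, a non-prefill-role engine with sequences still in flight retries scheduling as a decode step. Below is a minimal standalone sketch of that decision; the EngineRole values and the SchedulerOutput stand-in are hypothetical simplifications for illustration, and only the condition itself mirrors the diff.

# Hedged sketch of the reschedule-to-decode fallback introduced above.
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class EngineRole(Enum):
    # Hypothetical subset of roles; the diff only relies on Prefill.
    Hybrid = 1
    Prefill = 2


@dataclass
class SchedulerOutput:
    # Stand-in for the scheduler output; only `running` matters here.
    running: List[object] = field(default_factory=list)


def need_schedule_again(prefill: bool, output: SchedulerOutput, role: EngineRole, has_running: bool) -> bool:
    """Fall back to decode scheduling only when a prefill schedule
    yielded nothing and decoding is actually possible."""
    if not prefill:
        return False  # only reschedule after a prefill attempt
    if output.running:
        return False  # prefill produced work; no fallback needed
    if role == EngineRole.Prefill:
        return False  # prefill-only workers never decode
    return has_running  # decode only if sequences are in flight


# Empty prefill schedule on a hybrid engine with running seqs: reschedule.
assert need_schedule_again(True, SchedulerOutput(), EngineRole.Hybrid, True)
assert not need_schedule_again(True, SchedulerOutput(), EngineRole.Prefill, True)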
lmdeploy/pytorch/paging/scheduler.py (0 additions, 20 deletions)
@@ -6,7 +6,6 @@
 from typing import Dict, List

 from lmdeploy.messages import EventType, ScheduleMetrics
-from lmdeploy.pytorch.disagg.config import EngineRole
 from lmdeploy.utils import get_logger, logging_timer

 from ..config import CacheConfig, SchedulerConfig
Expand Down Expand Up @@ -198,16 +197,6 @@ def _schedule_prefill(self, prealloc_size: int = 0):
running: SeqList = []
token_count = 0

def _get_free_ratio():
num_free_blocks = self.block_manager.get_num_free_gpu_blocks()
num_all_blocks = self.cache_config.num_gpu_blocks
free_ratio = num_free_blocks / num_all_blocks
return free_ratio

def __evict_block_trie():
num_req = int(self.cache_config.num_gpu_blocks * 0.1) - self.block_manager.get_num_free_gpu_blocks() + 1
self.block_trie.evict(num_req)

def _to_running(seq: SchedulerSequence):
"""To running."""
seq.status = MessageStatus.RUNNING
@@ -231,15 +220,6 @@ def _reorder_waiting():
         if (len(running) >= max_batches or num_waiting == 0):
             return running, swap_in_map, swap_out_map, copy_map

-        # reserve some blocks for decoding to avoid too much eviction
-        if self.cache_config.role != EngineRole.Prefill:
-            free_ratio = _get_free_ratio()
-            if free_ratio < 0.1:
-                __evict_block_trie()
-                free_ratio = _get_free_ratio()
-                if free_ratio < 0.1:
-                    return running, swap_in_map, swap_out_map, copy_map
-
         waiting = _reorder_waiting()
         while len(waiting) > 0 and len(running) < max_batches:
             seq = waiting.pop(0)
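For reference, the guard removed here reserved roughly 10% of GPU blocks for decoding: when the free ratio fell below 0.1, it evicted just enough prefix-cache (block trie) blocks to restore the reserve, and stopped admitting prefill requests if the ratio stayed low. A sketch of the eviction arithmetic follows; the helper name is hypothetical, while the formula is copied from the removed __evict_block_trie.

# Sketch of the removed reservation arithmetic: keep at least 10% of GPU
# blocks free for decoding, evicting block-trie entries as needed.
FREE_TARGET = 0.1  # reserve fraction used by the removed code


def blocks_to_evict(num_gpu_blocks: int, num_free_blocks: int) -> int:
    # Mirrors `int(num_gpu_blocks * 0.1) - num_free_blocks + 1` from the
    # removed __evict_block_trie helper.
    return int(num_gpu_blocks * FREE_TARGET) - num_free_blocks + 1


# Example: 1000 blocks total, 60 free -> evict 41 blocks to restore the reserve.
assert blocks_to_evict(1000, 60) == 41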