Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/parallax/sglang/batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,17 @@ def form_sgl_batch_prefill(
) -> ForwardBatch:
"""Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow"""
sgl_reqs = transform_requests_to_sglang(requests)

def dummy_evict(*args):
pass

dummy_tree_cache = SimpleNamespace(
page_size=model_runner.server_args.page_size,
device=model_runner.device,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
evictable_size=0,
)
dummy_tree_cache.evict = dummy_evict
schedule_batch = ScheduleBatch.init_new(
reqs=sgl_reqs,
req_to_token_pool=model_runner.req_to_token_pool,
Expand Down Expand Up @@ -116,6 +122,7 @@ def select_batch(
ret.reqs = [origin_batch.reqs[i] for i in keep_indices]
if origin_batch.multimodal_inputs is not None:
ret.multimodal_inputs = [origin_batch.multimodal_inputs[i] for i in keep_indices]
ret.seq_lens_cpu = origin_batch.seq_lens_cpu[keep_indices]
ret.req_pool_indices = origin_batch.req_pool_indices[keep_indices_device]
ret.seq_lens = origin_batch.seq_lens[keep_indices_device]
ret.orig_seq_lens = origin_batch.orig_seq_lens[keep_indices_device]
Expand Down Expand Up @@ -189,6 +196,7 @@ def form_sgl_batch_decode(
# TODO: this is a hack to make the seq_lens correct due to select_batch is not refference running batch's seq_lens
# need to fix this
running_batch.seq_lens[ready_indices] += 1
running_batch.seq_lens_cpu[ready_indices] += 1
running_batch.orig_seq_lens[ready_indices] += 1

model_worker_batch = ret.get_model_worker_batch()
Expand Down