diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 18ac7df7..9a9ff9aa 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -68,11 +68,17 @@ def form_sgl_batch_prefill( ) -> ForwardBatch: """Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow""" sgl_reqs = transform_requests_to_sglang(requests) + + def dummy_evict(*args): + pass + dummy_tree_cache = SimpleNamespace( page_size=model_runner.server_args.page_size, device=model_runner.device, token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator, + evictable_size=0, ) + dummy_tree_cache.evict = dummy_evict schedule_batch = ScheduleBatch.init_new( reqs=sgl_reqs, req_to_token_pool=model_runner.req_to_token_pool, @@ -116,6 +122,7 @@ def select_batch( ret.reqs = [origin_batch.reqs[i] for i in keep_indices] if origin_batch.multimodal_inputs is not None: ret.multimodal_inputs = [origin_batch.multimodal_inputs[i] for i in keep_indices] + ret.seq_lens_cpu = origin_batch.seq_lens_cpu[keep_indices] ret.req_pool_indices = origin_batch.req_pool_indices[keep_indices_device] ret.seq_lens = origin_batch.seq_lens[keep_indices_device] ret.orig_seq_lens = origin_batch.orig_seq_lens[keep_indices_device] @@ -189,6 +196,7 @@ def form_sgl_batch_decode( # TODO: this is a hack to make the seq_lens correct due to select_batch is not refference running batch's seq_lens # need to fix this running_batch.seq_lens[ready_indices] += 1 + running_batch.seq_lens_cpu[ready_indices] += 1 running_batch.orig_seq_lens[ready_indices] += 1 model_worker_batch = ret.get_model_worker_batch()