diff --git a/pyproject.toml b/pyproject.toml index 633c880b..efa3f93d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,9 +47,15 @@ mac = [ ] gpu = [ + "sglang[all]==0.5.4.post1", + "mlx-lm==0.28.0", + "mlx[cpu]==0.29.1", +] + +vllm = [ + "vllm==0.11.0", "mlx-lm==0.28.0", "mlx[cpu]==0.29.1", - "sglang[all]==0.5.4.post1", ] benchmark = [ diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py index ea3c5d96..d2cda590 100644 --- a/src/parallax/server/executor.py +++ b/src/parallax/server/executor.py @@ -73,6 +73,8 @@ def __init__( start_layer: int, end_layer: int, dtype: str = "float16", + # Backend selection + gpu_backend: str = "sglang", use_hfcache: bool = False, # Scheduler Configs max_batch_size: Optional[int] = 8, @@ -111,38 +113,64 @@ def __init__( self.device = get_current_device() self.use_hfcache = use_hfcache logger.debug(f"Executor initializing on device: {self.device}") + self.backend_type = gpu_backend # Sharded Model if self.device == "cuda": - from sglang.srt.managers.schedule_batch import ScheduleBatch + if self.backend_type == "vllm": + from parallax.vllm.model_runner import ( + initialize_vllm_model_runner as initialize_cuda_model_runner, + ) - from parallax.sglang.model_runner import initialize_sgl_model_runner + logger.debug( + f"Initializing vLLM model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + elif self.backend_type == "sglang": + from sglang.srt.managers.schedule_batch import ( + ScheduleBatch as SGLangScheduleBatch, + ) - logger.debug( - f"Initializing CUDA model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" - ) - self.model_runner, self.config, self.tokenizer = initialize_sgl_model_runner( - model_repo, - start_layer, - end_layer, - kv_cache_memory_fraction, - attention_backend, - kv_block_size, - moe_runner_backend, - tp_rank, - tp_size, - nccl_port, - use_hfcache=self.use_hfcache, + from parallax.sglang.model_runner import ( + initialize_sgl_model_runner as initialize_cuda_model_runner, + ) + + logger.debug( + f"Initializing SGLang model runner for repo={model_repo}, layers=[{start_layer}, {end_layer})" + ) + else: + raise ValueError(f"Unsupported GPU backend type: {self.backend_type}") + + # Prepare all parameters for model runner initialization + model_runner_params = { + "model_repo": model_repo, + "start_layer": start_layer, + "end_layer": end_layer, + "kv_cache_memory_fraction": kv_cache_memory_fraction, + "attention_backend": attention_backend, + "kv_block_size": kv_block_size, + "max_num_tokens_per_batch": max_num_tokens_per_batch, + "dtype": dtype, + "moe_runner_backend": moe_runner_backend, + "tp_rank": tp_rank, + "tp_size": tp_size, + "nccl_port": nccl_port, + "using_hfcache": use_hfcache, + } + + self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner( + **model_runner_params ) logger.debug( f"CUDA model runner initialized. 
num_layers={self.config.get('num_hidden_layers')}" ) - self.tp_group = self.model_runner.tp_group - self.tp_cpu_group = self.tp_group.cpu_group - # SGL KV Cache Manager is already initialized in ScheduleBatch - # TODO: Replace ScheduleBatch to Parallax inflight batch - self.running_batch = ScheduleBatch(reqs=[], batch_is_full=False) self.cur_batch = None + self.running_batch = None + + if self.backend_type == "sglang": + self.running_batch = SGLangScheduleBatch(reqs=[], batch_is_full=False) + self.tp_group = self.model_runner.tp_group + self.tp_cpu_group = self.tp_group.cpu_group + else: logger.debug( f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})" @@ -419,107 +447,195 @@ def recv_requests_from_peer(self) -> List[Request]: def _prepare_cuda_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """ - Prepares inputs for SGLang model runner from a batch of prefill requests. - Returns: SGLang ScheduleBatch + Prepares inputs for CUDA backends from a batch of prefill requests. + Routes to SGLang or vLLM depending on backend_type. """ from sglang.srt.model_executor.forward_batch_info import PPProxyTensors - from parallax.sglang.batch_info import form_sgl_batch_prefill - batch_size = len(batched_requests) if batch_size == 0: return None - schedule_batch, forward_batch = form_sgl_batch_prefill(batched_requests, self.model_runner) - self.cur_batch = schedule_batch + # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None if not self.is_first_peer: - hidden_states = torch.cat( - [ - ( - req.hidden_states - if req.hidden_states.ndim == 2 - else req.hidden_states.unsqueeze(0) - ) - for req in batched_requests - ], - dim=0, - ) + # Concatenate hidden states from all requests + # For vLLM, we need to flatten to (total_tokens, hidden_size) + # For SGLang, we keep the batch dimension + hidden_states_list = [] + for req in batched_requests: + hs = req.hidden_states + if hs.ndim == 2: + # Already (seq_len, hidden_size) or (1, hidden_size) + hidden_states_list.append(hs) + elif hs.ndim == 3: + # (1, seq_len, hidden_size) -> (seq_len, hidden_size) + hidden_states_list.append(hs.squeeze(0)) + else: + # (hidden_size,) -> (1, hidden_size) + hidden_states_list.append(hs.unsqueeze(0)) + + # Concatenate along sequence dimension to get (total_tokens, hidden_size) + hidden_states = torch.cat(hidden_states_list, dim=0) + + # Create residual tensor with same shape residual = torch.zeros( hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device ) - pp_proxy_tensors = PPProxyTensors( - { - "hidden_states": hidden_states, - "residual": residual, - } - ) + + if self.backend_type == "vllm": + # For vLLM, pass directly as IntermediateTensors + from vllm.sequence import IntermediateTensors + + pp_proxy_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + # For SGLang, use PPProxyTensors + pp_proxy_tensors = PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) logger.debug(f"PP Proxy: hidden_states shape: {hidden_states.shape}") + + # Prepare lengths (common for both backends) lengths = [] for req in batched_requests: lengths.append(req.total_length) - ret = { - "forward_batch": forward_batch, - "pp_proxy_tensors": pp_proxy_tensors, - "lengths": torch.tensor(lengths, device=self.device), - "requests": batched_requests, - } - logger.debug(f"Prepared CUDA prefill batch (size={batch_size})") - return ret + 
lengths_tensor = torch.tensor(lengths, device=self.device) + + if self.backend_type == "vllm": + from parallax.vllm.batch_info import ( + compute_expected_intermediate_tokens, + form_vllm_batch_prefill, + resize_intermediate_tensors, + ) + + schedule_outputs_prefill = form_vllm_batch_prefill(batched_requests, self.model_runner) + + if not self.is_first_peer and pp_proxy_tensors is not None: + target_tokens = compute_expected_intermediate_tokens( + schedule_outputs_prefill, self.model_runner + ) + pp_proxy_tensors = resize_intermediate_tensors(pp_proxy_tensors, target_tokens) + + ret = { + "scheduler_output": schedule_outputs_prefill, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA prefill batch (vllm, size={batch_size})") + return ret + else: + from parallax.sglang.batch_info import form_sgl_batch_prefill + + schedule_batch, forward_batch = form_sgl_batch_prefill( + batched_requests, self.model_runner + ) + self.cur_batch = schedule_batch + + ret = { + "forward_batch": forward_batch, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA prefill batch (sglang, size={batch_size})") + return ret def _prepare_cuda_decode_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """ - Prepares inputs for SGLang model runner from a batch of decode requests. - Returns: SGLang ScheduleBatch + Prepares inputs for CUDA backends from a batch of decode requests. + Routes to SGLang or vLLM depending on backend_type. """ - from sglang.srt.model_executor.forward_batch_info import PPProxyTensors - - from parallax.sglang.batch_info import form_sgl_batch_decode batch_size = len(batched_requests) if batch_size == 0: return None - lengths = [] - for req in batched_requests: - lengths.append(req.total_length) - forward_batch = form_sgl_batch_decode( - batched_requests, - self.model_runner, - self.running_batch, - self.is_first_peer, - ) + # Prepare PP proxy tensors (common for both backends when not first peer) pp_proxy_tensors = None if not self.is_first_peer: - hidden_states = torch.cat( - [ - ( - req.hidden_states - if req.hidden_states.ndim == 2 - else req.hidden_states.unsqueeze(0) - ) - for req in batched_requests - ], - dim=0, - ) + # Concatenate hidden states from all requests + # For vLLM, we need to flatten to (total_tokens, hidden_size) + # For SGLang, we keep the batch dimension + hidden_states_list = [] + for req in batched_requests: + hs = req.hidden_states + if hs.ndim == 2: + # Already (seq_len, hidden_size) or (1, hidden_size) + hidden_states_list.append(hs) + elif hs.ndim == 3: + # (1, seq_len, hidden_size) -> (seq_len, hidden_size) + hidden_states_list.append(hs.squeeze(0)) + else: + # (hidden_size,) -> (1, hidden_size) + hidden_states_list.append(hs.unsqueeze(0)) + + # Concatenate along sequence dimension to get (total_tokens, hidden_size) + hidden_states = torch.cat(hidden_states_list, dim=0) + + # Create residual tensor with same shape residual = torch.zeros( hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device ) - pp_proxy_tensors = PPProxyTensors( + + if self.backend_type == "vllm": + from vllm.sequence import IntermediateTensors as CudaPPProxyTensors + else: + from sglang.srt.model_executor.forward_batch_info import ( + PPProxyTensors as CudaPPProxyTensors, + ) + pp_proxy_tensors = CudaPPProxyTensors( { "hidden_states": hidden_states, "residual": residual, } ) logger.debug(f"PP Proxy: 
hidden_states shape: {hidden_states.shape}") - ret = { - "forward_batch": forward_batch, - "pp_proxy_tensors": pp_proxy_tensors, - "lengths": torch.tensor(lengths, device=self.device), - "requests": batched_requests, - } - logger.debug(f"Prepared CUDA decode batch (size={batch_size})") - return ret + + # Prepare lengths (common for both backends) + lengths = [] + for req in batched_requests: + lengths.append(req.total_length) + lengths_tensor = torch.tensor(lengths, device=self.device) + + if self.backend_type == "vllm": + from parallax.vllm.batch_info import form_vllm_batch_decode + + scheduler_outputs_decode = form_vllm_batch_decode(batched_requests, self.model_runner) + ret = { + "scheduler_output": scheduler_outputs_decode, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA decode batch (vllm, size={batch_size})") + return ret + else: + from parallax.sglang.batch_info import form_sgl_batch_decode + + forward_batch = form_sgl_batch_decode( + batched_requests, + self.model_runner, + self.running_batch, + self.is_first_peer, + ) + + ret = { + "forward_batch": forward_batch, + "pp_proxy_tensors": pp_proxy_tensors, + "lengths": lengths_tensor, + "requests": batched_requests, + } + logger.debug(f"Prepared CUDA decode batch (sglang, size={batch_size})") + return ret def _prepare_mlx_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, Any]: """Prepares inputs for ShardedModel from a batch of prefill requests.""" @@ -812,7 +928,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]): Cuda specialized handle function. The main difference is to remove all the kv cache operations. """ - from parallax.sglang.batch_info import release_cuda_request + from parallax.sglang.batch_info import release_sglang_request if self.is_first_peer: # First peer can receive InitialRequests from the client RPC, @@ -832,13 +948,22 @@ def _handle_cuda_input_requests(self, requests: List[Request]): assert req.next_token_id is not None original_req.commit_new_token(req.next_token_id) + logger.debug( + f"[FirstPeer-CUDA] Committed token {req.next_token_id} for {req.request_id}, " + f"output_ids now has {len(original_req.output_ids)} tokens" + ) if len(req.routing_table) > 0: original_req.routing_table = req.routing_table # Check for termination. if self.scheduler.check_and_update_request_status(original_req): logger.debug(f"Releasing resources for finished request {req.request_id}") - release_cuda_request(self.running_batch, req.request_id) + if self.backend_type == "sglang": + release_sglang_request(self.running_batch, req.request_id) + elif self.backend_type == "vllm": + from parallax.vllm.batch_info import release_vllm_request + + release_vllm_request(self.model_runner, req.request_id) if not self.is_last_peer: self.finished_batch.append(req) else: @@ -866,8 +991,7 @@ def _handle_cuda_input_requests(self, requests: List[Request]): req, IntermediateRequest ), "Non-first peers must receive IntermediateRequests." if req.is_finished or req.hidden_states is None: - self.scheduler.evict_request(req.request_id) - release_cuda_request(self.running_batch, req.request_id) + self._release_and_evict_request(req.request_id) if not self.is_last_peer: self.finished_batch.append(req) else: @@ -1083,38 +1207,93 @@ def _process_batch_cuda( self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True ): """ - Process a batch of requests in CUDA. 
+ Process a batch of requests in CUDA, supports both vLLM and SGLang backends. """ - assert "forward_batch" in prepared_inputs, "forward_batch should be in cuda prepared inputs" - assert ( - "pp_proxy_tensors" in prepared_inputs - ), "pp_proxy_tensors should be in cuda prepared inputs" - forward_batch = prepared_inputs["forward_batch"] - pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] - logits_output, _ = self.model_runner.forward( - forward_batch=forward_batch, - pp_proxy_tensors=pp_proxy_tensors, - ) + if self.backend_type == "vllm": + # ========== vLLM Backend ========== + assert ( + "scheduler_output" in prepared_inputs + ), "scheduler_output should be provided for vLLM backend" + assert ( + "pp_proxy_tensors" in prepared_inputs + ), "pp_proxy_tensors should be in cuda prepared inputs" + scheduler_output = prepared_inputs["scheduler_output"] + pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] + # For vLLM, pp_proxy_tensors is already an IntermediateTensors object + intermediate_tensors = pp_proxy_tensors if pp_proxy_tensors is not None else None + if intermediate_tensors is not None: + logger.debug(f"vLLM: Using intermediate_tensors for PP (non-first peer)") + + # Import IntermediateTensors for type checking + + # Execute model with vLLM + output = self.model_runner.execute_model( + scheduler_output=scheduler_output, + intermediate_tensors=intermediate_tensors, + ) - if self.cur_batch: - if self.cur_batch.forward_mode.is_extend(): - # Merge the new batch into the running batch - if not self.cur_batch.is_empty(): - if self.running_batch.is_empty(): - self.running_batch = self.cur_batch - else: - # Merge running_batch with prefill batch - self.running_batch.merge_batch(self.cur_batch) - self.cur_batch = None + # Return appropriate output based on peer position + if return_decoded_tokens: + import torch + + sampled_token_ids = output.sampled_token_ids + if isinstance(sampled_token_ids, list) and len(sampled_token_ids) > 0: + # Convert to tensor: pad sequences to same length + max_len = max(len(seq) for seq in sampled_token_ids) + padded_tokens = [] + for seq in sampled_token_ids: + padded_seq = seq + [-1] * (max_len - len(seq)) # Pad with -1 + padded_tokens.append(padded_seq) + return torch.tensor(padded_tokens, dtype=torch.int64) + else: + return torch.tensor(sampled_token_ids, dtype=torch.int64) + else: + # Intermediate peer: return hidden states for next peer + final_hidden_states = output.tensors["hidden_states"] + output.tensors["residual"] + return final_hidden_states + + else: # self.backend_type == "sglang" + # ========== SGLang Backend ========== + assert ( + "forward_batch" in prepared_inputs + ), "forward_batch should be in cuda prepared inputs" + assert ( + "pp_proxy_tensors" in prepared_inputs + ), "pp_proxy_tensors should be in cuda prepared inputs" + + forward_batch = prepared_inputs["forward_batch"] + pp_proxy_tensors = prepared_inputs["pp_proxy_tensors"] + + # Execute model with SGLang + logits_output, _ = self.model_runner.forward( + forward_batch=forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + ) - if return_decoded_tokens: - next_token_ids = self.model_runner.sample(logits_output, forward_batch) - return next_token_ids - # Currently hack the result of (hidden_state + residual) here for GPU - final_hidden_states = ( - logits_output.tensors["hidden_states"] + logits_output.tensors["residual"] - ) - return final_hidden_states + # SGLang-specific batch management: merge prefill batch into running batch + if self.cur_batch: + if 
self.cur_batch.forward_mode.is_extend(): + # Merge the new batch into the running batch + if not self.cur_batch.is_empty(): + if self.running_batch.is_empty(): + self.running_batch = self.cur_batch + else: + # Merge running_batch with prefill batch + self.running_batch.merge_batch(self.cur_batch) + self.cur_batch = None + + # Return appropriate output based on peer position + if return_decoded_tokens: + # Last peer: sample and return token IDs + next_token_ids = self.model_runner.sample(logits_output, forward_batch) + return next_token_ids + else: + # Intermediate peer: return hidden states for next peer + # Note: SGLang stores hidden_states + residual separately + final_hidden_states = ( + logits_output.tensors["hidden_states"] + logits_output.tensors["residual"] + ) + return final_hidden_states def _process_batch_mlx( self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True @@ -1208,10 +1387,15 @@ def _release_and_evict_request(self, rid: str): """Release per-request resources and evict from scheduler. Best-effort, never raises.""" # Release resources if self.device == "cuda": - from parallax.sglang.batch_info import release_cuda_request - try: - release_cuda_request(self.running_batch, rid) + if self.backend_type == "vllm": + from parallax.vllm.batch_info import release_vllm_request + + release_vllm_request(self.model_runner, rid) + elif self.backend_type == "sglang": + from parallax.sglang.batch_info import release_sglang_request + + release_sglang_request(self.running_batch, rid) except Exception: pass else: @@ -1385,6 +1569,7 @@ def shutdown(self): def run_executor_process(args, gradient_server=None): """Run executor as a subprocess""" + executor = None try: executor = Executor.create_from_args(args, gradient_server) executor.run_loop() @@ -1393,7 +1578,8 @@ def run_executor_process(args, gradient_server=None): except Exception as e: logger.exception(e) finally: - executor.shutdown() + if executor is not None: + executor.shutdown() def stop_executor_process(executor_process): @@ -1414,6 +1600,7 @@ def create_executor_config(args: argparse.Namespace, gradient_server=None): "start_layer": args.start_layer, "end_layer": args.end_layer, "dtype": args.dtype, + "gpu_backend": args.gpu_backend if hasattr(args, "gpu_backend") else "sglang", "max_sequence_length": args.max_sequence_length if "max_sequence_length" in args else None, "max_batch_size": args.max_batch_size if "max_batch_size" in args else None, "kv_block_size": args.kv_block_size, diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py index 0d58fce8..613c0ce0 100644 --- a/src/parallax/server/server_args.py +++ b/src/parallax/server/server_args.py @@ -192,6 +192,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") + parser.add_argument( + "--gpu-backend", + type=str, + default="sglang", + choices=["sglang", "vllm"], + help="GPU backend to use", + ) + parser.add_argument( "--use-hfcache", action="store_true", diff --git a/src/parallax/sglang/batch_info.py b/src/parallax/sglang/batch_info.py index 8129e8e7..cd8f3767 100755 --- a/src/parallax/sglang/batch_info.py +++ b/src/parallax/sglang/batch_info.py @@ -205,7 +205,7 @@ def form_sgl_batch_decode( return forward_batch -def release_cuda_request(running_batch: ScheduleBatch, request_id: str): +def release_sglang_request(running_batch: ScheduleBatch, request_id: str): """Release KV Cache and other resources for finished/aborted requests.""" if running_batch is 
None or running_batch.is_empty(): return diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 70cab2bb..d1624a13 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -6,6 +6,7 @@ import logging import os +import random import sglang import sglang.srt.distributed.parallel_state @@ -221,17 +222,15 @@ def form_sgl_server_args( def initialize_sgl_model_runner( - original_model_path: str, + model_repo: str, start_layer: int, end_layer: int, kv_cache_memory_fraction: float, attention_backend: str, kv_block_size: int, moe_runner_backend: str, - tp_rank: int, - tp_size: int, - nccl_port: int, - use_hfcache: bool = False, + max_num_tokens_per_batch: int = 1024, + **kwargs, ): """ Creates a SGL ModelRunner object. @@ -242,6 +241,11 @@ def initialize_sgl_model_runner( """ apply_parallax_sglang_monkey_patch() + # Extract TP-related parameters from kwargs or use defaults + tp_rank = kwargs.get("tp_rank", 0) + tp_size = kwargs.get("tp_size", 1) + use_hfcache = kwargs.get("use_hfcache", False) + nccl_port = kwargs.get("nccl_port", None) # Use selective download for GPU models to save bandwidth and disk space from parallax.utils.selective_download import get_model_path_with_selective_download @@ -249,16 +253,16 @@ def initialize_sgl_model_runner( f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})" ) model_path = get_model_path_with_selective_download( - original_model_path, - start_layer=start_layer, - end_layer=end_layer, - local_files_only=use_hfcache, + model_repo, start_layer=start_layer, end_layer=end_layer, local_files_only=use_hfcache ) config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") + if nccl_port is None: + nccl_port = random.randint(4000, 5000) + # Handling mxfp4 arguments quant_method = config.get("quant_method", None) quantization_config = config.get("quantization_config", None) diff --git a/src/parallax/vllm/batch_info.py b/src/parallax/vllm/batch_info.py new file mode 100644 index 00000000..95f2c442 --- /dev/null +++ b/src/parallax/vllm/batch_info.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch +from vllm.sampling_params import SamplingParams as VLLMSamplingParams +from vllm.sampling_params import StructuredOutputsParams +from vllm.sequence import IntermediateTensors +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.request import Request as VLLMRequest + +from parallax.server.request import Request +from parallax.server.sampling.sampling_params import ( + SamplingParams as ParallaxSamplingParams, +) +from parallax_utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def compute_expected_intermediate_tokens(scheduler_output: Any, model_runner: Any) -> Optional[int]: + """ + Estimate the padded token count expected by vLLM for this batch. + + This function computes the total number of tokens including padding that vLLM + expects for data parallel processing. 
+ + Args: + scheduler_output: SchedulerOutput from vLLM scheduler + model_runner: The vLLM model runner instance + + Returns: + Expected total token count including padding, or None if unable to compute + """ + if scheduler_output is None: + return None + + total_tokens = getattr(scheduler_output, "total_num_scheduled_tokens", None) + if total_tokens is None: + return None + + try: + total_tokens = int(total_tokens) + except (TypeError, ValueError): + return None + + if model_runner is None: + return None + + get_num_input_tokens = getattr(model_runner, "_get_num_input_tokens", None) + get_dp_padding = getattr(model_runner, "get_dp_padding", None) + if get_num_input_tokens is None or get_dp_padding is None: + return None + + num_input_tokens = get_num_input_tokens(total_tokens) + num_pad, _ = get_dp_padding(num_input_tokens) + return num_input_tokens + num_pad + + +def pad_or_trim_tensor(tensor: torch.Tensor, target_len: int) -> torch.Tensor: + """ + Pad or trim a tensor to the target length along dimension 0. + + Args: + tensor: Input tensor to pad/trim + target_len: Target length for dimension 0. If negative, returns unchanged. + + Returns: + Tensor with dimension 0 adjusted to target_len + """ + if target_len < 0: + return tensor + current_len = tensor.shape[0] + if current_len == target_len: + return tensor + if current_len > target_len: + return tensor[:target_len] + pad_shape = (target_len - current_len,) + tensor.shape[1:] + pad = tensor.new_zeros(pad_shape) + return torch.cat((tensor, pad), dim=0) + + +def resize_intermediate_tensors( + intermediate_tensors: IntermediateTensors, target_len: Optional[int] +) -> IntermediateTensors: + """ + Resize all tensors in IntermediateTensors to match the target length. + + This is needed for vLLM pipeline parallelism when the actual token count + doesn't match the expected padded count for data parallel processing. + + Args: + intermediate_tensors: vLLM IntermediateTensors containing hidden states + target_len: Target token count. If None or negative, returns unchanged. + + Returns: + IntermediateTensors with all tensors resized to target_len + """ + if intermediate_tensors is None or target_len is None: + return intermediate_tensors + if target_len < 0: + return intermediate_tensors + + # Create a list to avoid "dictionary changed size during iteration". 
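+    # e.g. pad_or_trim_tensor on a (9, hidden) tensor with target_len=16 appends
+    # 7 zero-filled rows; with target_len=4 it keeps only the first 4 rows.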
+ for key, tensor in list(intermediate_tensors.items()): + intermediate_tensors[key] = pad_or_trim_tensor(tensor, target_len) + return intermediate_tensors + + +def transform_sampling_params_to_vllm(old_params: ParallaxSamplingParams) -> VLLMSamplingParams: + structured = ( + StructuredOutputsParams(json=old_params.json_schema) + if getattr(old_params, "json_schema", None) is not None + else None + ) + params = VLLMSamplingParams( + max_tokens=old_params.max_new_tokens, + min_tokens=old_params.min_new_tokens, + temperature=old_params.temperature, + top_p=old_params.top_p, + min_p=old_params.min_p, + top_k=old_params.top_k, + stop_token_ids=( + list(old_params.stop_token_ids) + if getattr(old_params, "stop_token_ids", None) is not None + else None + ), + ignore_eos=old_params.ignore_eos, + stop=old_params.stop_strs, + repetition_penalty=old_params.repetition_penalty, + presence_penalty=old_params.presence_penalty, + frequency_penalty=old_params.frequency_penalty, + structured_outputs=structured, + ) + return params + + +def _build_vllm_request( + req: Request, + sampling_params: VLLMSamplingParams, + model_runner: Any, + *, + include_outputs: bool, +) -> VLLMRequest: + block_hasher = getattr(model_runner, "request_block_hasher", None) + vllm_req = VLLMRequest( + request_id=req.request_id, + prompt_token_ids=getattr(req, "input_ids", None), + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=getattr(req, "eos_token_id", None), + arrival_time=getattr(req, "arrival_time", 0.0), + block_hasher=block_hasher, + ) + if include_outputs: + output_ids = getattr(req, "output_ids", None) or [] + if output_ids: + vllm_req.append_output_token_ids(output_ids) + return vllm_req + + +def form_vllm_batch_prefill( + batched_requests: List[Request], + model_runner: Any = None, +) -> Optional[SchedulerOutput]: + if not batched_requests: + return None + + if not hasattr(model_runner, "kv_cache_manager"): + raise RuntimeError( + "model_runner must have kv_cache_manager initialized. " + "Call model_runner.initialize_kv_cache_manager() first." 
+ ) + + kv_cache_manager = model_runner.kv_cache_manager + + num_common_prefix_blocks = [0] * len(model_runner.kv_cache_config.kv_cache_groups) + + created_vllm_requests: List[VLLMRequest] = [] + + new_request_data_list = [] + num_scheduled_tokens: Dict[str, int] = {} + total_tokens = 0 + + for req in batched_requests: + sampling_params = transform_sampling_params_to_vllm(req.sampling_params) + + vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=False) + created_vllm_requests.append(vllm_req) + + computed_blocks, num_computed_tokens = kv_cache_manager.get_computed_blocks(vllm_req) + + prompt_token_ids = getattr(req, "input_ids", None) or [] + num_new_tokens = max(len(prompt_token_ids) - num_computed_tokens, 0) + if num_new_tokens > 0: + new_blocks = kv_cache_manager.allocate_slots( + request=vllm_req, + num_new_tokens=num_new_tokens, + num_new_computed_tokens=num_computed_tokens, + new_computed_blocks=computed_blocks if num_computed_tokens > 0 else None, + ) + + if new_blocks is None: + logger.warning(f"Cannot allocate KV cache for request {req.request_id}") + for prev_req in created_vllm_requests[:-1]: + kv_cache_manager.free(prev_req) + return None + + all_blocks = computed_blocks + new_blocks if num_computed_tokens > 0 else new_blocks + else: + all_blocks = computed_blocks + + block_ids = all_blocks.get_block_ids() + + new_req_data = NewRequestData( + req_id=req.request_id, + prompt_token_ids=req.input_ids, + mm_features=[], + sampling_params=sampling_params, + pooling_params=None, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + lora_request=None, + prompt_embeds=None, + ) + new_request_data_list.append(new_req_data) + + scheduled_tokens = len(prompt_token_ids) + num_scheduled_tokens[req.request_id] = scheduled_tokens + total_tokens += scheduled_tokens + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_request_data_list, + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_tokens, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=num_common_prefix_blocks, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids=[], + grammar_bitmask=None, + kv_connector_metadata=None, + ) + + return scheduler_output + + +def form_vllm_batch_decode( + batched_requests: List[Request], + model_runner: Any = None, + scheduler: Any = None, + **kwargs, +) -> Optional[SchedulerOutput]: + if not batched_requests: + return None + + if not hasattr(model_runner, "kv_cache_manager"): + raise RuntimeError( + "model_runner must have kv_cache_manager initialized. " + "Call model_runner.initialize_kv_cache_manager() first." + ) + + kv_cache_manager = model_runner.kv_cache_manager + + req_ids: List[str] = [] + resumed_from_preemption: List[bool] = [] + new_token_ids: List[List[int]] = [] + resumed_req_token_ids: List[List[int] | None] = [] + new_block_ids: List[tuple[List[int], ...] 
| None] = [] + num_computed_tokens: List[int] = [] + num_output_tokens: List[int] = [] + num_scheduled_tokens: Dict[str, int] = {} + + for req in batched_requests: + req_ids.append(req.request_id) + resumed_from_preemption.append(False) + + # For GPU workers (non-first peer), IntermediateRequest doesn't have output_ids + # We need to get it from vLLM's CachedRequestState in model_runner + output_ids = getattr(req, "output_ids", None) or [] + + # If this request doesn't have output_ids (IntermediateRequest case), + # try to get it from model_runner's cached request state (vLLM internal state) + if not output_ids and hasattr(model_runner, "requests"): + cached_req_state = model_runner.requests.get(req.request_id) + if cached_req_state is not None: + output_ids = getattr(cached_req_state, "output_token_ids", []) + logger.debug( + f"[Decode] Retrieved output_token_ids from vLLM CachedRequestState for " + f"{req.request_id}: len={len(output_ids)}" + ) + + # Fallback: try scheduler if available + if not output_ids and scheduler is not None: + running_req = scheduler.get_running_request(req.request_id) + if running_req is not None: + output_ids = getattr(running_req, "output_ids", None) or [] + logger.debug( + f"[Decode] Retrieved output_ids from scheduler for {req.request_id}: " + f"len={len(output_ids)}" + ) + + if output_ids: + last_token = output_ids[-1] + new_token_ids.append([last_token]) + else: + new_token_ids.append([]) + + resumed_req_token_ids.append([]) + + sampling_params = transform_sampling_params_to_vllm(req.sampling_params) + vllm_req = _build_vllm_request(req, sampling_params, model_runner, include_outputs=True) + + prompt_ids = getattr(req, "input_ids", None) or [] + # For decode stage, computed_token_count should be the total number of tokens + # that have been processed (including all output tokens). + # In pipeline parallelism, this must match what GPU worker expects. 
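+        # e.g. prompt_len=8 with output_ids=[a, b, c] gives 8 + 3 - 1 = 10 computed
+        # tokens: the newest token (c) is the one being scheduled for this decode step.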
+ if output_ids: + # All tokens (prompt + all generated outputs) have been computed + computed_token_count = len(prompt_ids) + len(output_ids) - 1 + else: + # First decode step: only prompt has been computed + computed_token_count = len(prompt_ids) + vllm_req.num_computed_tokens = computed_token_count + + # Debug logging to track state synchronization + logger.debug( + f"[Decode] req_id={req.request_id}, prompt_len={len(prompt_ids)}, " + f"output_len={len(output_ids)}, computed_tokens={computed_token_count}" + ) + + new_blocks = kv_cache_manager.allocate_slots( + request=vllm_req, + num_new_tokens=1, + num_new_computed_tokens=0, + ) + + if new_blocks is None: + logger.warning(f"Cannot allocate KV cache for decode request {req.request_id}") + return None + + new_block_ids.append(new_blocks.get_block_ids(allow_none=True)) + num_computed_tokens.append(computed_token_count) + num_output_tokens.append(len(output_ids)) + num_scheduled_tokens[req.request_id] = 1 + + cached_req_data = CachedRequestData( + req_ids=req_ids, + resumed_from_preemption=resumed_from_preemption, + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=sum(num_scheduled_tokens.values()), + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0] * getattr(kv_cache_manager, "num_kv_cache_groups", 1), + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids=[], + grammar_bitmask=None, + kv_connector_metadata=None, + ) + + return scheduler_output + + +def release_vllm_request(model_runner: Any, request_id: str): + if not hasattr(model_runner, "kv_cache_manager"): + logger.warning(f"KV cache manager not found when releasing request {request_id}") + return + + kv_cache_manager = model_runner.kv_cache_manager + + try: + kv_cache_manager.coordinator.free(request_id) + logger.debug(f"Released KV cache for request {request_id}") + except Exception as e: + logger.warning(f"Error releasing KV cache for request {request_id}: {e}") diff --git a/src/parallax/vllm/model_runner.py b/src/parallax/vllm/model_runner.py new file mode 100644 index 00000000..cf25efe7 --- /dev/null +++ b/src/parallax/vllm/model_runner.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +import importlib +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from mlx_lm.utils import load_config +from vllm.config import ( + CacheConfig, + CompilationConfig, + DeviceConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.distributed.parallel_state import GroupCoordinator as VLLMGroupCoordinator +from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_utils import ( + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + get_request_block_hasher, + init_none_hash, +) +from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheGroupSpec, KVCacheTensor +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +from parallax.sglang.monkey_patch_utils.weight_loader_filter import ( + apply_weight_loader_filter_patch, + set_layer_range_for_filtering, +) +from parallax.utils.tokenizer_utils import load_tokenizer +from parallax.vllm.monkey_patch import apply_parallax_vllm_monkey_patch +from parallax_utils.logging_config import get_logger + +logger = 
get_logger(__name__) + + +class ParallaxVLLMGroupCoordinator(VLLMGroupCoordinator): + """ + Parallax version of vLLM's GroupCoordinator. + Override is_first_rank and is_last_rank to use layer ranges instead of process ranks. + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, torch.distributed.Backend], + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + pp_start_layer: int = 0, + pp_end_layer: int = 0, + num_hidden_layers: int = 0, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + use_device_communicator=use_device_communicator, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + self.pp_start_layer = pp_start_layer + self.pp_end_layer = pp_end_layer + self.num_hidden_layers = num_hidden_layers + + @property + def is_first_rank(self) -> bool: + """Return whether this is the first pipeline stage based on layer range.""" + return self.pp_start_layer == 0 + + @property + def is_last_rank(self) -> bool: + """Return whether this is the last pipeline stage based on layer range.""" + return self.pp_end_layer >= self.num_hidden_layers + + +def _create_kv_cache_config_from_specs( + kv_cache_group: KVCacheGroupSpec, + attn_layers: List[str], + kv_cache_memory_fraction: float, +) -> KVCacheConfig: + import torch + + free_memory, total_memory = torch.cuda.mem_get_info(0) + available_memory = int(free_memory * kv_cache_memory_fraction) + + logger.info( + f"Available GPU memory for KV cache: " + f"{available_memory / (1024**3):.2f} GB " + f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" + ) + + page_size_bytes = kv_cache_group.kv_cache_spec.page_size_bytes + + max_blocks_by_memory = available_memory // page_size_bytes + + num_blocks = max(100, min(1000, int(max_blocks_by_memory * 0.8))) + + logger.debug(f"Calculated KV cache blocks: {num_blocks} (max possible: {max_blocks_by_memory})") + + tensor_size_bytes = page_size_bytes * num_blocks + + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[ + KVCacheTensor( + size=tensor_size_bytes, + shared_by=attn_layers, + ) + ], + kv_cache_groups=[kv_cache_group], + ) + + return kv_cache_config + + +class ParallaxVLLMModelRunner(GPUModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + kv_cache_config: Optional[KVCacheConfig], + device: str, + start_layer: int, + end_layer: int, + num_hidden_layers: int, + ): + self.start_layer = start_layer + self.end_layer = end_layer + self.num_hidden_layers = num_hidden_layers + self.num_shard_layers = end_layer - start_layer + + self.is_first_peer = start_layer == 0 + self.is_last_peer = end_layer == num_hidden_layers + + self.pp_rank = 0 + self.pp_size = 1 + + self.request_block_hasher: Optional[Callable[[Any], List[Any]]] = None + self.enable_prefix_caching: bool = True + + super().__init__(vllm_config=vllm_config, device=torch.device(device)) + self.kv_cache_config = kv_cache_config + + logger.info( + f"ParallaxVLLMModelRunner initialized: layers [{start_layer}, {end_layer}), " + f"is_first={self.is_first_peer}, is_last={self.is_last_peer}" + ) + + def _create_kv_cache_config(self, kv_cache_memory_fraction: float = None) -> KVCacheConfig: + logger.debug("Generating KV cache configuration from model...") + + try: + kv_cache_specs = self.model.get_kv_cache_spec() + except AttributeError: + 
logger.warning( + "Cannot access get_kv_cache_spec due to cudagraph wrapper, using fallback method" + ) + kv_cache_specs = None + + import torch + + free_memory, total_memory = torch.cuda.mem_get_info(self.device.index or 0) + + memory_fraction = ( + kv_cache_memory_fraction + if kv_cache_memory_fraction is not None + else self.cache_config.gpu_memory_utilization + ) + available_memory = int(free_memory * memory_fraction) + + logger.debug( + f"Available GPU memory for KV cache: " + f"{available_memory / (1024**3):.2f} GB " + f"({memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" + ) + + if kv_cache_specs is not None: + kv_cache_configs = get_kv_cache_configs( + vllm_config=self.vllm_config, + kv_cache_specs=[kv_cache_specs], + available_memory=[available_memory], + ) + kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + else: + logger.debug("Using fallback KV cache configuration") + + model = self.model + hf_config = model.model.config + num_attention_heads = getattr(hf_config, "num_attention_heads", 8) + hidden_size = getattr(hf_config, "hidden_size", 1024) + head_size = hidden_size // num_attention_heads + + from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec + + model_dtype = self.vllm_config.model_config.dtype + if isinstance(model_dtype, str): + try: + from vllm.utils.torch_utils import ( + STR_DTYPE_TO_TORCH_DTYPE, # type: ignore + ) + except Exception: + # Older/newer vLLM versions may not expose torch_utils. + # Fall back silently and default to float16. + STR_DTYPE_TO_TORCH_DTYPE = {} + model_dtype = STR_DTYPE_TO_TORCH_DTYPE.get(model_dtype, torch.float16) + + kv_cache_group = KVCacheGroupSpec( + layer_names=[f"model.layers.{i}" for i in range(self.start_layer, self.end_layer)], + kv_cache_spec=FullAttentionSpec( + block_size=self.cache_config.block_size, + num_kv_heads=num_attention_heads, + head_size=head_size, + dtype=model_dtype, + ), + ) + + layer_names = [f"model.layers.{i}" for i in range(self.start_layer, self.end_layer)] + + kv_cache_config = _create_kv_cache_config_from_specs( + kv_cache_group=kv_cache_group, + attn_layers=layer_names, + kv_cache_memory_fraction=memory_fraction, + ) + + logger.debug( + f"KV cache config generated: " + f"num_blocks={kv_cache_config.num_blocks}, " + f"num_groups={len(kv_cache_config.kv_cache_groups)}" + ) + + return kv_cache_config + + def initialize_kv_cache_manager(self, max_model_len: int) -> KVCacheManager: + logger.debug("Initializing vLLM KVCacheManager...") + + if self.kv_cache_config is None: + self.kv_cache_config = self._create_kv_cache_config() + + kv_cache_manager = KVCacheManager( + kv_cache_config=self.kv_cache_config, + max_model_len=max_model_len, + enable_caching=True, + use_eagle=False, + log_stats=True, + enable_kv_cache_events=False, + dcp_world_size=1, + ) + + self.kv_cache_manager = kv_cache_manager + cache_config = self.vllm_config.cache_config + enable_prefix = cache_config.enable_prefix_caching + if enable_prefix is None: + enable_prefix = True + + self.enable_prefix_caching = False + + self.request_block_hasher = None + if enable_prefix and kv_cache_manager.block_size is not None: + try: + hashing_mod = importlib.import_module("vllm.utils.hashing") + get_hash_fn_by_name: Callable[[str], Callable[[Any], bytes]] = getattr( + hashing_mod, "get_hash_fn_by_name" + ) + hash_fn = get_hash_fn_by_name(cache_config.prefix_caching_hash_algo) + init_none_hash(hash_fn) + except (ModuleNotFoundError, AttributeError) as exc: + logger.warning("Unable to initialize prefix cache 
hashing: %s", exc) + + def simple_hash_fn(obj: Any) -> bytes: + return str(hash(str(obj))).encode("utf-8") + + hash_fn = simple_hash_fn + logger.info("Using simple fallback hash function for prefix caching") + + block_size = kv_cache_manager.block_size + if block_size is None and self.kv_cache_config.kv_cache_groups: + block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + if block_size is not None: + self.request_block_hasher = get_request_block_hasher(block_size, hash_fn) + logger.info("Initialized prefix cache block hasher with block_size=%d", block_size) + + logger.debug( + f"KVCacheManager initialized: block_size={kv_cache_manager.block_size}, " + f"usage={kv_cache_manager.usage:.2%}" + ) + + return kv_cache_manager + + def load_model(self) -> None: + logger.debug(f"Loading vLLM model with layers [{self.start_layer}, {self.end_layer})...") + + from vllm.distributed.utils import get_pp_indices + + original_get_pp_indices = get_pp_indices + + def custom_get_pp_indices(num_layers: int, rank: int, world_size: int): + logger.debug( + f"custom_get_pp_indices called: num_layers={num_layers}, " + f"returning [{self.start_layer}, {self.end_layer})" + ) + return self.start_layer, self.end_layer + + import vllm.distributed.utils + + vllm.distributed.utils.get_pp_indices = custom_get_pp_indices + + try: + super().load_model() + logger.debug( + f"Successfully loaded {self.num_shard_layers} layers " + f"[{self.start_layer}:{self.end_layer}]" + ) + + finally: + vllm.distributed.utils.get_pp_indices = original_get_pp_indices + + def execute_model(self, scheduler_output, intermediate_tensors=None): + """ + Execute the model with the given scheduler output and intermediate tensors. + If this is not the first peer, and the intermediate_tensors buffer is not initialized, + initialize it. 
+ """ + if not self.is_first_peer and self.intermediate_tensors is None: + self.intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device, + ) + logger.debug("Successfully initialized intermediate_tensors buffer") + + return super().execute_model(scheduler_output, intermediate_tensors) + + +def initialize_vllm_model_runner( + model_repo: str, + start_layer: int, + end_layer: int, + kv_cache_memory_fraction: float, + attention_backend: str, + kv_block_size: int, + max_num_tokens_per_batch: int = 1024, + dtype: str = "float16", + **kwargs, +) -> Tuple[ParallaxVLLMModelRunner, Dict, Any]: + from parallax.utils.selective_download import get_model_path_with_selective_download + + logger.info( + f"Initializing vLLM model runner for {model_repo}, " f"layers=[{start_layer}, {end_layer})" + ) + + model_path = get_model_path_with_selective_download( + model_repo, + start_layer=start_layer, + end_layer=end_layer, + ) + + config = load_config(model_path) + tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) + dtype = config.get("torch_dtype", "bfloat16") + + num_hidden_layers = config.get("num_hidden_layers") + is_first_peer = start_layer == 0 + is_last_peer = end_layer == num_hidden_layers + + # Apply Parallax vLLM monkey patches for pipeline parallelism + try: + apply_parallax_vllm_monkey_patch(is_first_stage=is_first_peer, is_last_stage=is_last_peer) + logger.debug( + f"Applied Parallax vLLM monkey patches: is_first_stage={is_first_peer}, is_last_stage={is_last_peer}" + ) + except Exception as e: + logger.warning("Failed to apply Parallax vLLM monkey patches: %s", e) + + # Apply layer-range-based weight file filtering before any model load. + # Reuse the generic monkey patch used by sglang implementation to reduce + # local weight file reads when loading a partial layer shard. 
+ try: + set_layer_range_for_filtering(start_layer, end_layer, num_hidden_layers) + apply_weight_loader_filter_patch() + logger.debug( + f"Applied weight loader filter monkey patch for layers [{start_layer}, {end_layer})" + ) + except Exception as e: + logger.warning("Failed to apply weight loader filter patch for vLLM loading: %s", e) + + # For single process, always use pp_size=1 + virtual_pp_size = 1 + + import os + + import vllm.distributed.parallel_state as parallel_state + + if not parallel_state.model_parallel_is_initialized(): + logger.debug(f"Initializing vLLM distributed environment...") + + # Set environment variables for distributed initialization + if "RANK" not in os.environ: + os.environ["RANK"] = "0" + if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = "1" + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = "0" + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "12355" + + try: + parallel_state.init_distributed_environment() + + # Initialize with pp_size=1 for single process + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + ) + + # Monkey patch the PP group with our custom Parallax coordinator + # that uses layer ranges to determine is_first_rank/is_last_rank + original_pp_group = parallel_state._PP + if original_pp_group is not None: + # Get backend from device_group (torch is already imported at module level) + import torch.distributed + + backend = torch.distributed.get_backend(original_pp_group.device_group) + + # Create a Parallax PP group coordinator + # Need to wrap ranks in a list of lists for group_ranks parameter + parallax_pp_group = ParallaxVLLMGroupCoordinator( + group_ranks=[original_pp_group.ranks], + local_rank=original_pp_group.local_rank, + torch_distributed_backend=backend, + use_device_communicator=original_pp_group.use_device_communicator, + use_message_queue_broadcaster=(original_pp_group.mq_broadcaster is not None), + group_name="pp", + pp_start_layer=start_layer, + pp_end_layer=end_layer, + num_hidden_layers=num_hidden_layers, + ) + # Replace the PP group + parallel_state._PP = parallax_pp_group + logger.debug( + f"Replaced vLLM PP group with Parallax coordinator: " + f"is_first_rank={parallax_pp_group.is_first_rank}, " + f"is_last_rank={parallax_pp_group.is_last_rank}" + ) + + logger.debug(f"vLLM distributed environment initialized") + except Exception as e: + logger.warning(f"Failed to initialize distributed environment: {e}") + logger.error(f"vLLM distributed initialization failed. 
Error: {e}") + raise + + if end_layer > num_hidden_layers: + raise ValueError( + f"end_layer ({end_layer}) cannot be greater than " + f"num_hidden_layers ({num_hidden_layers})" + ) + + model_config = ModelConfig( + model=str(model_path), + tokenizer=str(model_path), + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=0, + max_model_len=getattr(config, "max_position_embeddings", 4096), + ) + + cache_config = CacheConfig( + block_size=kv_block_size, + gpu_memory_utilization=kv_cache_memory_fraction, + swap_space=0, + cache_dtype="auto", + ) + + parallel_config = ParallelConfig( + pipeline_parallel_size=virtual_pp_size, + tensor_parallel_size=1, + distributed_executor_backend=None, + ) + + device_config = DeviceConfig(device="cuda") + load_config_for_config = LoadConfig(load_format="auto") + + max_batched_tokens = max(max_num_tokens_per_batch, model_config.max_model_len) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=max_batched_tokens, + max_num_seqs=256, + max_model_len=model_config.max_model_len, + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config_for_config, + lora_config=None, + speculative_config=None, + observability_config=None, + prompt_adapter_config=None, + quant_config=None, + compilation_config=CompilationConfig(), + kv_transfer_config=None, + kv_events_config=None, + additional_config={}, + instance_id="", + ) + + model_runner = ParallaxVLLMModelRunner( + vllm_config=vllm_config, + kv_cache_config=None, + device="cuda", + start_layer=start_layer, + end_layer=end_layer, + num_hidden_layers=num_hidden_layers, + ) + + logger.info("Loading vLLM model (partial layers)...") + model_runner.load_model() + logger.info("vLLM model loaded successfully") + + logger.debug("Letting vLLM automatically generate KV cache configuration...") + + kv_cache_specs = model_runner.get_kv_cache_spec() + + if not kv_cache_specs: + raise RuntimeError("No KV cache specs found in the loaded model") + + import torch + + free_memory, total_memory = torch.cuda.mem_get_info(0) + available_memory = int(free_memory * kv_cache_memory_fraction) + + logger.info( + f"Available GPU memory for KV cache: " + f"{available_memory / (1024**3):.2f} GB " + f"({kv_cache_memory_fraction:.1%} of {free_memory / (1024**3):.2f} GB)" + ) + + from vllm.v1.core.kv_cache_utils import ( + generate_scheduler_kv_cache_config, + get_kv_cache_configs, + ) + + kv_cache_configs = get_kv_cache_configs( + vllm_config=model_runner.vllm_config, + kv_cache_specs=[kv_cache_specs], + available_memory=[available_memory], + ) + + kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs) + + model_runner.kv_cache_config = kv_cache_config + + logger.info("Initializing GPUModelRunner KV cache...") + model_runner.initialize_kv_cache(kv_cache_config) + logger.info("GPUModelRunner KV cache initialized successfully") + + logger.info("Initializing KV Cache Manager...") + model_runner.initialize_kv_cache_manager(max_model_len=model_config.max_model_len) + logger.info("KV Cache Manager initialized successfully") + + return model_runner, config, tokenizer diff --git a/src/parallax/vllm/monkey_patch.py b/src/parallax/vllm/monkey_patch.py new file mode 100644 index 00000000..5c098730 --- /dev/null +++ b/src/parallax/vllm/monkey_patch.py @@ -0,0 +1,27 @@ +""" +Monkey patches for vLLM to support Parallax pipeline parallelism. 
+ +This module provides a unified entry point for applying all vLLM-related monkey patches +required for Parallax's distributed inference with pipeline parallelism. +""" + +from parallax.vllm.monkey_patch_utils.weight_loader import ( + apply_vllm_weight_loader_patch, + set_vllm_pipeline_stage, +) + + +## Here are patch functions for vLLM +## Hopefully, when vLLM supports pipeline parallelism natively in the way we need, +## we can remove these patches +def apply_parallax_vllm_monkey_patch(is_first_stage: bool, is_last_stage: bool): + """ + Apply all Parallax monkey patches for vLLM. + + Args: + is_first_stage: Whether this is the first pipeline stage. + is_last_stage: Whether this is the last pipeline stage. This affects + whether lm_head weights are expected to be loaded. + """ + set_vllm_pipeline_stage(is_first_stage, is_last_stage) + apply_vllm_weight_loader_patch() diff --git a/src/parallax/vllm/monkey_patch_utils/weight_loader.py b/src/parallax/vllm/monkey_patch_utils/weight_loader.py new file mode 100644 index 00000000..45e6fc87 --- /dev/null +++ b/src/parallax/vllm/monkey_patch_utils/weight_loader.py @@ -0,0 +1,98 @@ +""" +Monkey patch for vLLM weight loading to skip non-existent weights on different pipeline stages. +This is similar to the approach used in sglang monkey patches. +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + +_vllm_patch_applied = False +_is_first_stage = False # Default to False +_is_last_stage = True # Default to True for safety + + +def set_vllm_pipeline_stage(is_first_stage: bool, is_last_stage: bool): + """Set whether this is the first and/or last pipeline stage.""" + global _is_first_stage, _is_last_stage + _is_first_stage = is_first_stage + _is_last_stage = is_last_stage + logger.debug( + f"Set vLLM pipeline stage: is_first_stage={_is_first_stage}, is_last_stage={_is_last_stage}" + ) + + +def apply_vllm_weight_loader_patch(): + """ + Apply monkey patch to vLLM's default loader to skip initialization checks + for weights that are not expected on certain pipeline stages. + + - Skips `embed_tokens` check on non-first stages. + - Skips `lm_head` check on non-last stages. 
+ """ + global _vllm_patch_applied + + if _vllm_patch_applied: + logger.debug("vLLM weight loader patch already applied, skipping") + return + + try: + from vllm.model_executor.model_loader import default_loader + + original_load_weights = default_loader.DefaultModelLoader.load_weights + + def patched_load_weights(self, model: Any, model_config: Any): + """Patched load_weights that handles embed_tokens and lm_head for pipeline parallelism.""" + global _is_first_stage, _is_last_stage + + try: + # Call original load_weights + original_load_weights(self, model, model_config) + except ValueError as e: + error_msg = str(e) + uninitialized_weights = "not initialized from checkpoint" in error_msg + + # Case 1: embed_tokens.weight not found + if "model.embed_tokens.weight" in error_msg and uninitialized_weights: + if not _is_first_stage: + # Expected behavior for non-first pipeline stages + logger.info( + "Skipping embed_tokens.weight initialization check on non-first pipeline stage" + ) + else: + # This is the first stage, embed_tokens should be initialized + logger.error( + "embed_tokens.weight not initialized on first pipeline stage, this is an error" + ) + raise + + # Case 2: lm_head.weight not found + elif "lm_head.weight" in error_msg and uninitialized_weights: + if not _is_last_stage: + # Expected behavior for non-last pipeline stages + logger.info( + "Skipping lm_head.weight initialization check on non-last pipeline stage" + ) + else: + # This is the last stage, lm_head should be initialized + logger.error( + "lm_head.weight not initialized on last pipeline stage, this is an error" + ) + raise + + # Case 3: Other errors + else: + # Different error, re-raise + raise + + # Apply the patch + default_loader.DefaultModelLoader.load_weights = patched_load_weights + _vllm_patch_applied = True + logger.info("Successfully applied vLLM weight loader patch for pipeline parallelism") + + except ImportError as e: + logger.warning(f"Could not apply vLLM weight loader patch: {e}") + except Exception as e: + logger.error(f"Error applying vLLM weight loader patch: {e}") + raise