diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index f449f582..dd0b819e 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -57,7 +57,6 @@
     pad_prefix_caches,
 )
 from parallax_utils.logging_config import get_logger
-from parallax_utils.utils import compute_max_batch_size
 
 logger = get_logger(__name__)
 
@@ -73,7 +72,7 @@ def __init__(
         end_layer: int,
         dtype: str = "float16",
         # Scheduler Configs
-        max_batch_size: Optional[int] = None,
+        max_batch_size: Optional[int] = 8,
         max_sequence_length: Optional[int] = None,
         max_tokens_in_kv_pool: Optional[int] = None,
         # Controlling perfill / decode ratio
@@ -222,16 +221,17 @@ def __init__(
             self.kv_cache_manager.max_num_tokens
 
         # Scheduler: derive final max_batch_size with KV constraints
-        max_batch_size = compute_max_batch_size(
-            requested_max_batch_size=max_batch_size,
-            max_sequence_len=max_sequence_length,
-            device=self.device,
-            kv_cache_memory_fraction=kv_cache_memory_fraction,
-            num_shard_layers=self.num_shard_layers,
-            num_key_value_heads=self.num_key_value_heads,
-            head_dim=self.head_dim,
-            dtype=self.dtype,
-        )
+        # Disabled for now: compute_max_batch_size is not working on GPU devices
+        # max_batch_size = compute_max_batch_size(
+        #     requested_max_batch_size=max_batch_size,
+        #     max_sequence_len=max_sequence_length,
+        #     device=self.device,
+        #     kv_cache_memory_fraction=kv_cache_memory_fraction,
+        #     num_shard_layers=self.num_shard_layers,
+        #     num_key_value_heads=self.num_key_value_heads,
+        #     head_dim=self.head_dim,
+        #     dtype=self.dtype,
+        # )
 
         self.scheduler = Scheduler(
             max_batch_size=max_batch_size,
diff --git a/src/parallax/server/scheduler.py b/src/parallax/server/scheduler.py
index 583939dd..fd500b81 100644
--- a/src/parallax/server/scheduler.py
+++ b/src/parallax/server/scheduler.py
@@ -45,7 +45,7 @@ def __init__(
         """
         self.max_batch_size = max_batch_size
         self.max_num_tokens_per_batch = max_num_tokens_per_batch
-        self.micro_batch_size = max_batch_size // micro_batch_ratio
+        self.micro_batch_size = max(1, max_batch_size // micro_batch_ratio)
         self.scheduler_wait_ms = scheduler_wait_ms
         self.is_first_peer = is_first_peer
         if is_first_peer:
diff --git a/src/parallax/server/server_args.py b/src/parallax/server/server_args.py
index 8d0e1839..7b8eead8 100644
--- a/src/parallax/server/server_args.py
+++ b/src/parallax/server/server_args.py
@@ -103,7 +103,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--max-batch-size",
         type=int,
-        default=None,
+        default=8,
         help="Maximum batch size for processing requests",
     )
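
Note on the executor change: the disabled `compute_max_batch_size` call capped the batch size by what the KV cache can physically hold, which is why the static default of 8 (in both `Executor.__init__` and the `--max-batch-size` flag) now takes its place. The sketch below is a rough reconstruction of that kind of KV-cache bound, using the parameter names visible at the call site above; `estimate_max_batch_size`, `available_memory_bytes`, and `dtype_bytes` are hypothetical, not the actual `parallax_utils.utils` implementation.

```python
from typing import Optional


def estimate_max_batch_size(
    requested_max_batch_size: Optional[int],
    max_sequence_len: int,
    available_memory_bytes: int,  # hypothetical: free device memory
    kv_cache_memory_fraction: float,
    num_shard_layers: int,
    num_key_value_heads: int,
    head_dim: int,
    dtype_bytes: int = 2,  # e.g. float16
) -> int:
    # Each cached token stores one key and one value vector per layer per KV head.
    kv_bytes_per_token = (
        2 * num_shard_layers * num_key_value_heads * head_dim * dtype_bytes
    )
    # Portion of device memory reserved for the KV pool.
    kv_budget = int(available_memory_bytes * kv_cache_memory_fraction)
    # How many worst-case (full-length) sequences fit in that budget.
    fits = max(1, kv_budget // (kv_bytes_per_token * max_sequence_len))
    if requested_max_batch_size is None:
        return fits
    return min(requested_max_batch_size, fits)


# 24 GiB free, 40% reserved for KV cache, a 28-layer shard with 8 KV heads of
# dim 128 at fp16, sequences up to 4096 tokens -> batch size 21.
print(estimate_max_batch_size(None, 4096, 24 * 2**30, 0.4, 28, 8, 128))
```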
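Note on the scheduler change: `max(1, ...)` guards against integer floor division producing a micro batch size of 0 when `max_batch_size` is smaller than `micro_batch_ratio`, which is now easy to hit with the default of 8. A minimal repro; the `micro_batch_ratio` value of 16 is purely illustrative, since its default is not shown in this diff:

```python
max_batch_size = 8       # new default from this diff
micro_batch_ratio = 16   # illustrative value, large enough to trigger the bug

old_micro = max_batch_size // micro_batch_ratio          # 0: empty micro batches
new_micro = max(1, max_batch_size // micro_batch_ratio)  # clamped to at least 1

assert old_micro == 0
assert new_micro == 1
```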