Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions src/parallax/server/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
pad_prefix_caches,
)
from parallax_utils.logging_config import get_logger
from parallax_utils.utils import compute_max_batch_size

logger = get_logger(__name__)

Expand All @@ -73,7 +72,7 @@ def __init__(
end_layer: int,
dtype: str = "float16",
# Scheduler Configs
max_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = 8,
max_sequence_length: Optional[int] = None,
max_tokens_in_kv_pool: Optional[int] = None,
# Controlling prefill / decode ratio
Expand Down Expand Up @@ -222,16 +221,17 @@ def __init__(
self.kv_cache_manager.max_num_tokens

# Scheduler: derive final max_batch_size with KV constraints
max_batch_size = compute_max_batch_size(
requested_max_batch_size=max_batch_size,
max_sequence_len=max_sequence_length,
device=self.device,
kv_cache_memory_fraction=kv_cache_memory_fraction,
num_shard_layers=self.num_shard_layers,
num_key_value_heads=self.num_key_value_heads,
head_dim=self.head_dim,
dtype=self.dtype,
)
# Remove this for now as it's not working on gpu devices
# max_batch_size = compute_max_batch_size(
# requested_max_batch_size=max_batch_size,
# max_sequence_len=max_sequence_length,
# device=self.device,
# kv_cache_memory_fraction=kv_cache_memory_fraction,
# num_shard_layers=self.num_shard_layers,
# num_key_value_heads=self.num_key_value_heads,
# head_dim=self.head_dim,
# dtype=self.dtype,
# )

self.scheduler = Scheduler(
max_batch_size=max_batch_size,
Expand Down
2 changes: 1 addition & 1 deletion src/parallax/server/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
"""
self.max_batch_size = max_batch_size
self.max_num_tokens_per_batch = max_num_tokens_per_batch
self.micro_batch_size = max_batch_size // micro_batch_ratio
self.micro_batch_size = max(1, max_batch_size // micro_batch_ratio)
self.scheduler_wait_ms = scheduler_wait_ms
self.is_first_peer = is_first_peer
if is_first_peer:
Expand Down
2 changes: 1 addition & 1 deletion src/parallax/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--max-batch-size",
type=int,
default=None,
default=8,
help="Maximum batch size for processing requests",
)

Expand Down