Skip to content

Commit 147168c

Browse files
Merge pull request #3692 from AI-Hypercomputer:hengtaoguo-rl-ep
PiperOrigin-RevId: 901473973
2 parents a049f9a + 16c444f commit 147168c

2 files changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,7 @@ class VLLM(BaseModel):
1731 1731
hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
1732 1732
swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
1733 1733
enable_dp_attention: bool = Field(False, description="Enable the attn_dp mesh axis in vLLM.")
1734 +
enable_expert_parallel: bool = Field(False, description="Enable expert parallelism in vLLM.")
1734 1735
async_scheduling: bool = Field(False, description="Enable asynchronous scheduling in vLLM.")
1735 1736
max_num_batched_tokens: Optional[int] = Field(None, description="Max number of batched tokens in vLLM.")
1736 1737
max_num_seqs: Optional[int] = Field(None, description="Max number of sequences in vLLM.")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def create_rl_components(
541 541
rollout_vllm_server_mode=trainer_config.rl.use_agentic_rollout,
542 542
rollout_vllm_kwargs={
543 543
"hf_overrides": trainer_config.vllm_hf_overrides,
544 -
"enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1,
544 +
"enable_expert_parallel": sampler_config.enable_expert_parallel,
545 545
"enable_prefix_caching": True, # Enable prefix caching to speed up generation for long prompts
546 546
},
547 547
rollout_vllm_sampling_kwargs={

0 commit comments

Comments (0)