Unexpected long actor_time when train_ppo_ray #246

Open
LSC527 opened this issue Mar 14, 2024 · 9 comments

LSC527 commented Mar 14, 2024

Training configuration:

--ref_num_nodes 1 --ref_num_gpus_per_node 2 --reward_num_nodes 1 --reward_num_gpus_per_node 2 --critic_num_nodes 1 --critic_num_gpus_per_node 4 --actor_num_nodes 2 --actor_num_gpus_per_node 8 --vllm_num_engines 2 --vllm_tensor_parallel_size 4 --micro_train_batch_size 4 --train_batch_size 64 --micro_rollout_batch_size 4 --rollout_batch_size 64 --max_epochs 1 --prompt_max_len 1024 --generate_max_len 1024 --zero_stage 3 --bf16 --adam_offload --flash_attn --gradient_checkpointing --perf

The actor model is a 70B llama2 and the critic model is a 13B llama2.
When computing action_log_probs with the actor model, I observed an unexpectedly long runtime: actor_time is as high as 150 seconds.

        # log probs
        start = time.time()
        action_log_probs = self.actor(sequences, num_actions, attention_mask)
        actor_time = time.time() - start

Profiling shows that an all_gather communication lasting about 80 seconds occurs at the beginning of the actor model's forward pass that computes action_log_probs.
[profiler trace screenshot]
I suspected a multi-node communication issue, but an extra profile of the actor model's training step showed it takes only about 50 seconds. I'm not sure what causes the abnormal actor time.

hijkzzz (Collaborator) commented Mar 14, 2024

Got it, we'll look into it. We've been quite busy with work recently, so we may not get to it right away.

wuxibin89 self-assigned this Mar 15, 2024
wuxibin89 added the enhancement (New feature or request) label Mar 15, 2024
wuxibin89 (Collaborator) commented Mar 15, 2024

@LSC527 If vLLM is enabled, training and inference are split apart, so the inference and training workloads of the actor model and the critic model are comparable; I suggest giving the two the same number of GPUs.

wuxibin89 (Collaborator):

> Profiling shows that an all_gather communication lasting about 80 seconds occurs at the beginning of the actor model's forward pass that computes action_log_probs.

This all_gather overhead comes from synchronizing the actor parameters to vLLM: after the training phase ends, the parameters are collected onto rank 0 of the actor model via an all_gather and then broadcast to all vLLM ranks.
https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/ppo_actor.py#L142-L145
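For reference, a minimal sketch of this per-parameter pattern (not the actual OpenRLHF implementation; the `model_update_group` process group linking actor rank 0 with the vLLM engine ranks and the rank layout are assumptions for illustration):

```python
import torch.distributed as dist
import deepspeed

def broadcast_actor_weights_to_vllm(actor_model, model_update_group):
    for name, param in actor_model.named_parameters():
        # ZeRO-3 shards each parameter across ranks; GatheredParameters issues an
        # all_gather so every rank temporarily holds the full tensor.
        with deepspeed.zero.GatheredParameters([param]):
            if dist.get_rank() == 0:
                # Only actor rank 0 pushes the full parameter to the vLLM ranks
                # (assuming actor rank 0 is global rank 0 in model_update_group).
                dist.broadcast(param.data, src=0, group=model_update_group)
```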

wuxibin89 (Collaborator):

There should be some room for optimization here: right now every parameter goes through its own all_gather + broadcast. Multiple parameters could be packed into a chunk to reduce the number of communication calls.
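A rough sketch of this chunked variant (illustrative only, not a patch; `chunk_numel` and the helper names are made up, and the receiving vLLM side would still need the parameter names/shapes to split the buffer back up):

```python
import torch
import torch.distributed as dist
import deepspeed

def _flush_chunk(params, group):
    # One all_gather for the whole chunk, then one broadcast of the flat buffer.
    with deepspeed.zero.GatheredParameters(params):
        if dist.get_rank() == 0:
            flat = torch.cat([p.data.reshape(-1) for p in params])
            dist.broadcast(flat, src=0, group=group)

def broadcast_in_chunks(actor_model, model_update_group, chunk_numel=256 * 1024 * 1024):
    chunk, numel = [], 0
    for _, param in actor_model.named_parameters():
        chunk.append(param)
        # Under ZeRO-3 the full size lives in ds_numel (a DeepSpeed attribute).
        numel += getattr(param, "ds_numel", param.numel())
        if numel >= chunk_numel:
            _flush_chunk(chunk, model_update_group)
            chunk, numel = [], 0
    if chunk:
        _flush_chunk(chunk, model_update_group)
```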

wuxibin89 added the help wanted (Extra attention is needed) label Mar 15, 2024
LSC527 (Author) commented Mar 15, 2024

> @LSC527 If vLLM is enabled, training and inference are split apart, so the inference and training workloads of the actor model and the critic model are comparable; I suggest giving the two the same number of GPUs.

@wuxibin89 My actor model is a 70B llama2 while the critic is a 13B llama2, which is much smaller, so I gave the critic fewer GPUs.

> Profiling shows that an all_gather communication lasting about 80 seconds occurs at the beginning of the actor model's forward pass that computes action_log_probs.
>
> This all_gather overhead comes from synchronizing the actor parameters to vLLM: after the training phase ends, the parameters are collected onto rank 0 of the actor model via an all_gather and then broadcast to all vLLM ranks. https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/ppo_actor.py#L142-L145

Why does the cost of _broadcast_to_vllm show up in actor_time? https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ppo_utils/experience_maker.py#L257-L259

wuxibin89 (Collaborator):

Hmmm... as far as I understand, CUDA kernels, including NCCL communication, are executed asynchronously, so the actor_time measurement here may be the point where a synchronization is triggered.
https://pytorch.org/docs/stable/notes/cuda.html
It might be worth adding a torch.cuda.synchronize() after _broadcast_to_vllm to verify this:
https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/ppo_actor.py#L145
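A small helper illustrating this suggestion (not existing OpenRLHF code): explicit synchronize calls ensure that pending asynchronous work, e.g. a preceding NCCL broadcast, is not billed to the forward-pass timing.

```python
import time
import torch

def timed_forward(actor, *args, **kwargs):
    torch.cuda.synchronize()   # drain whatever was queued before timing starts
    start = time.time()
    out = actor(*args, **kwargs)
    torch.cuda.synchronize()   # wait until the forward kernels actually finish
    return out, time.time() - start
```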

wuxibin89 (Collaborator):

@LSC527 Could you check whether actor_time is long for every micro rollout batch, or only for the first batch?

LSC527 (Author) commented Mar 20, 2024

@wuxibin89 actor_time is long at every step, and it remains long even after I remove _broadcast_to_vllm entirely. So far I have only observed this behavior with actor_num_nodes > 1 combined with ZeRO-3.

torch.distributed.barrier()

There is a torch.distributed.barrier() after _broadcast_to_vllm, so its cost should not be counted into actor_time; actor_time looks like it is purely the actor model's ZeRO-3 forward time.
I'll keep investigating.

LSC527 (Author) commented Mar 21, 2024

@wuxibin89 In the end I added barriers before and after ray.get(llm.generate.remote()) and found that the extra time comes from this line. Without the barriers, the extra time gets attributed to actor_time.

outputs = ray.get(llm.generate.remote(sampling_params=sampling_params, prompt_token_ids=prompt_token_ids))
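A sketch of that measurement (illustrative only, reusing the same llm, sampling_params and prompt_token_ids as the line above): wrapping the vLLM generate call with torch.distributed.barrier() makes every actor rank reach the call together, so the generation cost is isolated and no longer leaks into the later actor forward timing.

```python
import time
import ray
import torch.distributed as dist

dist.barrier()
start = time.time()
outputs = ray.get(
    llm.generate.remote(sampling_params=sampling_params, prompt_token_ids=prompt_token_ids)
)
dist.barrier()
generate_time = time.time() - start
```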
