
Is save checkpoint not yet supported for ppo ray trainer? #256

Open
mickel-liu opened this issue Mar 27, 2024 · 5 comments

@mickel-liu

When I set save_steps to anything other than -1, the program raises an exception:

```
self.actor.model, os.path.join(args.ckpt_path, "_actor"), tag, args.max_ckpt_num, args.max_ckpt_mem
AttributeError: 'Namespace' object has no attribute 'ckpt_path'
```

```python
if global_step % args.save_steps == 0:
    tag = f"global_step{global_step}"
    # Both calls read args.ckpt_path / args.max_ckpt_num / args.max_ckpt_mem,
    # none of which are defined by train_ppo_ray.py's argument parser.
    self.strategy.save_ckpt(
        self.actor.model, os.path.join(args.ckpt_path, "_actor"), tag, args.max_ckpt_num, args.max_ckpt_mem
    )
    self.strategy.save_ckpt(
        self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
    )
```

These three args (ckpt_path, max_ckpt_num, max_ckpt_mem) are indeed not defined in train_ppo_ray.py, and I don't see args.save_path being used anywhere either.

I did see this issue was mentioned in #133, wondering if there's any update.
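
For anyone hitting the same AttributeError, a minimal sketch of the argument definitions that would satisfy the save_ckpt call above. The argument names come from the traceback, but the defaults and help text here are assumptions, not the project's:

```python
# Hypothetical additions to the parser in train_ppo_ray.py. The argument
# names match what ppo_trainer.py reads; the defaults are guesses.
parser.add_argument("--ckpt_path", type=str, default="./ckpt/ppo_ray",
                    help="directory where the _actor/_critic checkpoints are written")
parser.add_argument("--max_ckpt_num", type=int, default=3,
                    help="keep at most this many checkpoints; older ones are rotated out")
parser.add_argument("--max_ckpt_mem", type=int, default=1000,
                    help="disk budget for retained checkpoints (units depend on the strategy)")
```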

@hijkzzz
Collaborator

hijkzzz commented Mar 27, 2024

Yes, we haven't fully developed and tested this feature yet. Contributions are welcome!

@mickel-liu
Author

I'm happy to look into it, but how have you been saving models so far?

@suehyunpark

Hi @mickel-liu, have you figured this out? I have no choice but to use train_ppo_ray.py for PPO instead of train_ppo.py, because it doesn't OOM during model loading in my configuration. I am looking into ways to save checkpoints during/after training, and was hoping you had delved into this feature as well.

@mickelliu
Contributor

> Hi @mickel-liu, have you figured this out? I have no choice but to use train_ppo_ray.py for PPO instead of train_ppo.py, because it doesn't OOM during model loading in my configuration. I am looking into ways to save checkpoints during/after training, and was hoping you had delved into this feature as well.

Hi, I did look into the code and found that the checkpoint-saving feature is not yet implemented. But saving checkpoints wasn't actually what I was looking for: I wanted the actual model weights, not the intermediate training states that "checkpoint" refers to in this repo. So I changed the code on my fork, and it now saves model checkpoints after a pre-set number of iterations. Here's the code in my fork: https://github.com/mickelliu/OpenRLHF/blob/a7f21aa26ac027fcf30ca1c588e01cf07c67cb6f/openrlhf/trainer/ppo_trainer.py#L428-L442

Regardless of whether the ckpt feature gets officially implemented, train_ppo_ray.py will save a model checkpoint at the end of training.
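
For context, a rough sketch of the kind of periodic full-model save the fork adds. This is not the fork's exact code; self.strategy.save_model and its argument order are assumptions based on how the trainer already saves the final model at the end of training, so check the linked fork (or your OpenRLHF version) for the real call:

```python
# Inside the PPO training loop: save full model weights every
# args.save_steps iterations, rather than a resumable DeepSpeed
# training state. save_model's signature is assumed here.
if args.save_steps > 0 and global_step % args.save_steps == 0:
    save_dir = os.path.join(args.save_path, f"iter_{global_step}")
    self.strategy.save_model(self.actor, self.tokenizer, save_dir)
```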

@suehyunpark


Thanks for the quick reply and for sharing your code! I'm glad to know that saving the trained model is that simple. Although the checkpointing feature would be a great addition, this fix solves my issue.
