[Bug:] Cannot use the `fused` flag in default optimizer of PPO #1770
Hello,
Could you please share some runs/runtime on standard environments? (to have a better idea of the potential gain) I would recommend you to use the RL Zoo and log things using W&B (or do runs with …).
Right now, the solution to your problem is to define a custom policy or to fork SB3.
EDIT: if you want a significant runtime boost, you can have a look at https://github.com/araffin/sbx
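A minimal sketch of the custom-policy route mentioned above (not from the original thread; it relies on `ActorCriticPolicy._build` creating the optimizer last, assumes a CUDA device is available, and `FusedAdamPolicy` is a hypothetical name):

```python
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class FusedAdamPolicy(ActorCriticPolicy):
    # Hypothetical workaround (not part of SB3): let the parent build the
    # networks and a plain Adam first, then move everything to CUDA and
    # rebuild the optimizer with fused=True so it only ever sees CUDA params.
    def _build(self, lr_schedule) -> None:
        super()._build(lr_schedule)
        self.to(th.device("cuda"))  # assumes a CUDA device is available
        self.optimizer = th.optim.Adam(self.parameters(), lr=lr_schedule(1), fused=True)


model = PPO(FusedAdamPolicy, "CartPole-v1", device="cuda", verbose=0)
model.learn(total_timesteps=10_000)
```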
To do so, I would have to fix this issue locally and test on all the standard environments. I don't have the capacity for that just now, unfortunately. I did test this locally for my use case and noticed a small improvement in very short runs, but not enough to justify spending much more time fixing this issue upstream. The PyTorch docs (here) also state a speed-up, but don't quantify it.
Hello,
The benchmark is still needed to know whether this is something that should be done on other algorithms too, or whether it should only be mentioned in the doc (with a link to your fork).
Btw, as I mentioned before, if you want a real performance boost, you can have a look at https://github.com/araffin/sbx (usually faster when run on CPU only, when not using a CNN; see https://arxiv.org/abs/2310.05808)
Actually my PR is generic to anything that inherits from …. I will update this issue and the PR with benchmark results when I am able to run them.
Thanks for the suggestion. I will consider sbx in the future, but for my current project I'll have to stick with SB3.
@araffin The current GitHub Actions based CI will skip our GPU tests, but it looks like GPU runners are in beta, starting last month. Just pointing it out since there is a waiting list to sign up for. See: …
I did an initial test running the following with and without the `fused` flag:

Without `fused`: …
With `fused`: …

This was a single run. It shows a 5.5% reduction in runtime, but since this is a single run I don't know the noise yet. The GPU is an RTX A5000 and the CPU is an AMD Ryzen 9 7900X with hyperthreading disabled.
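For reference, a sketch of how such a with/without comparison could be timed in plain Python (not the original benchmark script; the environment name and step count are arbitrary placeholders, and the fused run assumes a build where the flag can actually be passed through):

```python
import time

from stable_baselines3 import PPO


def timed_run(optimizer_kwargs: dict) -> float:
    # Pendulum-v1 and 100k steps are placeholders, not the benchmark used in this thread.
    model = PPO(
        "MlpPolicy",
        "Pendulum-v1",
        device="cuda",
        policy_kwargs=dict(optimizer_kwargs=optimizer_kwargs),
        seed=0,
        verbose=0,
    )
    start = time.perf_counter()
    model.learn(total_timesteps=100_000)
    return time.perf_counter() - start


print(f"without fused: {timed_run({}):.1f} s")
# Requires the device-ordering fix discussed in this issue; otherwise this raises.
print(f"with fused:    {timed_run({'fused': True}):.1f} s")
```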
Hello,

CPU only: …
One CPU: …
GPU, no fused: …
GPU, fused: …
SBX CPU (still using RL Zoo: https://rl-baselines3-zoo.readthedocs.io/en/master/guide/sbx.html): …
SBX GPU: …

As I wrote before, when not using a CNN, PPO on CPU is usually the fastest (20-25% faster). For now: …
For Atari games, with CNN:
CPU: Time (mean ± σ): 93.623 s ± 0.901 s
I ran this again 10 times over, with and without `fused`. The fused version only reduces the average elapsed time from 36:55 to 36:35, so that is only a 0.9% saving. It is still worth merging #1771, which doesn't introduce the `fused` flag itself but makes it possible to use it if anyone wishes to. It also ensures the optimizers don't see parameter types changing underneath them between initialization and first use.
I ran the same on my machine with `-m 5` and I got:

…

So in this case I actually experience a slowdown with `fused`.
I should mention that although this is true even for my workload (a single training instance runs nearly as fast on CPUs, 20 of them, as it does on the GPU), it is not the case when I do hyperparameter tuning with many parallel instances of rl_zoo3 (in separate processes, not using optuna threads). It seems to me that when GPU use is enabled, the CPU usage of a single training instance is limited to a single core and the GPU is only partially utilized. I am then able to run enough instances in parallel to max out my hardware. In contrast, training a single instance on only the CPU seems to max out usage on all CPUs, so it is then not beneficial to run hyperparameter tuning in parallel. The difference between the two setups, for me, is enormous: it is the difference between getting nothing at all versus getting some results after a weekend run of rl_zoo3 on my environment. That is why I still use the GPU with PPO.
I think you should take a look at the run I did with a single CPU (to disable inter-op parallelism) and related issues (search "num threads pytorch"). I appreciate the PR you did, but the current results don't justify the change. This would also introduce inconsistency between on/off-policy algorithms, and all algorithms in SB3 Contrib would have to be adjusted too.
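For context, the single-CPU setup mentioned above amounts to limiting PyTorch's thread pools, roughly as follows (a sketch, not the exact configuration used in the benchmarks; call these before any other torch work):

```python
import torch as th

# Limit PyTorch to one intra-op and one inter-op thread. For small MLP
# policies this often makes CPU PPO faster and avoids core oversubscription
# when several tuning processes run in parallel.
th.set_num_threads(1)
th.set_num_interop_threads(1)
```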
Thanks for the tip, I will look into it.
Ok, no problem. Feel free to close this issue and the PR. Thanks for your input on this matter.
Closing as not planned for now, will re-open in case of new results/other cases that justify the change. |
The default Adam optimizer has a `fused` flag which, according to the docs, is significantly faster than the default when used on CUDA. Using it with PPO generates an exception which complains that the parameters are not of CUDA type.

The `fused` parameter can be specified to PPO using `policy_kwargs = dict(optimizer_kwargs={'fused': True})`.
But the issue is in the following lines of code:

stable-baselines3/stable_baselines3/common/on_policy_algorithm.py, lines 133 to 136 in c8fda06
Before line 133 above, the correct device has been initialised in `self.device`. But the `policy_class` is initialized without it in line 133, so it initialises with the `cpu` device, and that also initialises the optimizer with the `cpu` device. In line 136, the device of the `policy_class` is updated to the correct one, but by then it is too late, because the optimizer has already been initialized and it thought the device was `cpu`.

This is a problem with the `fused` flag, because the Adam optimiser does check it and then double-checks `self.parameters()` to ensure they are of the correct type, and complains, in my case, that they are not of `cuda` type.
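The same ordering constraint can be seen with a bare PyTorch module (a sketch, assuming a CUDA device is available; the exact check and error message may differ between PyTorch versions):

```python
import torch as th
import torch.nn as nn

net = nn.Linear(4, 2)  # parameters start out on the CPU

try:
    # Adam inspects the parameters when fused=True and rejects CPU tensors here.
    th.optim.Adam(net.parameters(), lr=3e-4, fused=True)
except RuntimeError as exc:
    print(exc)

# Moving the module to the GPU *before* building the optimizer avoids the error,
# which is exactly the ordering that on_policy_algorithm.py currently prevents.
net.to("cuda")
opt = th.optim.Adam(net.parameters(), lr=3e-4, fused=True)
```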
If the `policy_class` in line 133 above was passed the correct device (i.e. `self.device`) in the initialization in the first place, it could set it correctly before `MlpExtractor` gets initialized. `MlpExtractor` gets initialized to the parent class's device in the lines below:

stable-baselines3/stable_baselines3/common/policies.py, lines 568 to 581 in c8fda06
Here is the traceback I get: