[enhancement] Polyak Averaging could be done faster #93
Why not, do you have some benchmark for that? I would first run line profiler and see where the bottlenecks are before optimizing things without knowing if they have a real impact.
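The comment suggests line profiler; as a self-contained sketch using the standard library's cProfile instead (the `polyak_update` and `train` functions here are hypothetical stand-ins for the agent's real methods), profiling a training loop could look like:

```python
import cProfile
import io
import pstats

def polyak_update(params, targets, tau):
    # Plain-Python Polyak average: target <- (1 - tau) * target + tau * param
    for i, (p, t) in enumerate(zip(params, targets)):
        targets[i] = (1 - tau) * t + tau * p

def train(steps=100):
    # Dummy training loop that only performs the target-network update
    params = [1.0] * 8
    targets = [0.0] * 8
    for _ in range(steps):
        polyak_update(params, targets, 0.005)

profiler = cProfile.Profile()
profiler.enable()
train()
profiler.disable()

# Print the 5 most expensive functions by cumulative time
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(5)
print(stream.getvalue())
```

The cumulative-time view makes it easy to see whether the target update is a meaningful fraction of a training step before spending time optimizing it.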
Sure, since it's a small change I will benchmark both and see if it makes any difference.
After a quick run of line profiler, here are the results for SAC:
One bottleneck is hard to reduce. In the train method, the cost is more evenly spread out; the most time-consuming operations each take around 10%.
So optimizing the Polyak average sounds good ;)
I tested these on SAC; there is a good 1.5-1.8x speedup, more on the GPU than on the CPU because of data transfers.

```python
import torch as th

def fast_polyak(agent):
    # Reusable ones tensor so addcmul_ performs the scaled add in place
    # without allocating a new tensor per parameter
    one = th.ones(1, requires_grad=False).to(agent.device)
    for param, target_param in zip(agent.critic.parameters(), agent.critic_target.parameters()):
        target_param.data.mul_(1 - agent.tau)
        target_param.data.addcmul_(param.data, one, value=agent.tau)

def slow_polyak(agent):
    # Builds two intermediate tensors for the weighted sum before copying
    for param, target_param in zip(agent.critic.parameters(), agent.critic_target.parameters()):
        target_param.data.copy_((1 - agent.tau) * target_param.data + agent.tau * param.data)

# how openai does it in their codebase
def openai_polyak(agent):
    for param, target_param in zip(agent.critic.parameters(), agent.critic_target.parameters()):
        target_param.data.mul_(1 - agent.tau)
        target_param.data.add_(agent.tau * param.data)
```
This is actually quite large: at 1 million Polyak updates, this shaves off 28 minutes on CPU and 2 hours 11 minutes on GPU.
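A minimal, self-contained benchmark sketch of the two update styles on plain tensors (tensor shapes, counts, and the `tau` value here are arbitrary assumptions, not the thread's actual benchmark setup):

```python
import time
import torch

tau = 0.005
# Stand-ins for critic / target-critic parameters
params = [torch.randn(64, 64) for _ in range(8)]
targets = [torch.randn(64, 64) for _ in range(8)]
one = torch.ones(1)

def slow_update():
    # Weighted sum via intermediate tensors, then an in-place copy
    for p, t in zip(params, targets):
        t.data.copy_((1 - tau) * t.data + tau * p.data)

def fast_update():
    # In-place scale followed by a fused scaled add
    for p, t in zip(params, targets):
        t.data.mul_(1 - tau)
        t.data.addcmul_(p.data, one, value=tau)

def bench(fn, n=100):
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return time.perf_counter() - start

print(f"slow: {bench(slow_update):.4f}s  fast: {bench(fast_update):.4f}s")
```

On GPU, the gap should widen further since the in-place variant avoids allocating and transferring intermediate tensors each step.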
@partiallytyped Could you quickly try on CPU but with […]? That's the only case where I did not see an improvement yet.
This is rather minor, but Polyak averaging in DQN/SAC/TD3 could be done faster, with far fewer intermediate tensors, using
torch.addcmul_
https://pytorch.org/docs/stable/torch.html#torch.addcmul
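To make the suggestion concrete, a small sketch checking that the `mul_` + `addcmul_` combination matches the textbook Polyak update on toy tensors (shapes and `tau` chosen arbitrarily for illustration):

```python
import torch

tau = 0.005
target = torch.ones(3)
source = torch.full((3,), 2.0)

# Reference: classic out-of-place Polyak update
expected = (1 - tau) * target + tau * source

# In-place variant: scale target, then add tau * source * 1 in one fused op
one = torch.ones(1)
target.mul_(1 - tau)
target.addcmul_(source, one, value=tau)

assert torch.allclose(target, expected)
```

`addcmul_` computes `self += value * tensor1 * tensor2` in place; passing a ones tensor as `tensor2` reduces it to a scaled add without any intermediate allocation.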