
Use single forward pass in shared model architectures #156

Merged
merged 2 commits into Toni-SM:develop on Jun 15, 2024

Conversation

lopatovsky (Contributor) commented Jun 10, 2024

Single forward pass

Motivation:

When the shared model is applied, the forward pass is called twice: once for the policy and once for the value. The inputs to both calls are identical, so the output can be cached to improve performance.

Note: the single forward pass also affects autograd graph construction, so a significant speedup occurs during the backward pass phase as well.
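The idea can be sketched as follows. This is a minimal illustration with hypothetical class and method names (not the actual skrl implementation): a shared policy/value model runs the costly trunk once on the policy call, caches both head outputs, and the subsequent value call reuses the cached result instead of recomputing.

```python
import torch
import torch.nn as nn

class SharedModel(nn.Module):
    """Hypothetical shared policy/value model with a single forward pass."""

    def __init__(self, obs_dim=8, act_dim=2, hidden=64):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.policy_head = nn.Linear(hidden, act_dim)
        self.value_head = nn.Linear(hidden, 1)
        self._cached_value = None

    def forward(self, obs, role):
        if role == "policy":
            # Policy pass: run the shared trunk once and cache the value
            # head's output for the value call that follows.
            features = self.trunk(obs)
            self._cached_value = self.value_head(features)
            return self.policy_head(features)
        # Value pass: reuse the cached result instead of recomputing the
        # trunk. This assumes the value call always follows the policy
        # call on the same input (see the note further below).
        out = self._cached_value
        self._cached_value = None
        return out
```

Because the cached value tensor is produced inside the same autograd graph as the policy output, the backward pass also traverses the shared trunk only once.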

Speed eval:

  • Big neural network (units: [2048, 1024, 1024, 512])

  • 3840 steps

  • Running on top of the OIGE env simulation (constant for each run)

| Library | Mixed precision | Single forward pass | Time (s) | Slowdown vs. RLGames (mixed pr. = True) | Slowdown vs. RLGames (mixed pr. = False) |
|---|---|---|---|---|---|
| RLGames | False | Yes | 141 | 1.259x | 1 (base) |
| RLGames | True | Yes | 112 | 1 (base) | 0.794x |
| SKRL | True | No | 199 | 1.777x | 1.411x |
| SKRL | True | Yes | 151 | 1.348x | 1.071x |

Quality eval:

We trained a policy for our task multiple times with each configuration. We did not observe any statistically significant difference in the quality of the final results.

Notice: the single-pass and double-pass runs would be identical in an ideal world, but because of finite floating-point precision and a different order of gradient computation, they diverge gradually.
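The gradual divergence comes from floating-point addition not being associative: summing the same numbers in a different order can give slightly different results, and these tiny differences compound over training steps. A small self-contained illustration (plain Python floats, not tied to this PR's code):

```python
# Floating-point addition is not associative: adding the same values in
# a different order yields slightly different results, which is why
# reordering gradient computations makes runs diverge over time.
vals = [0.1] * 10 + [1e8, -1e8]

forward = sum(vals)            # small values accumulated first
backward = sum(reversed(vals)) # large values cancel first

print(forward, backward)  # close, but not bit-identical
```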

Note:

- This implementation is minimalistic, but it is risky to generalise, as it requires that the value forward pass always follows the policy forward pass. To make it safer, we could cache the input and check whether the next input is the same:

  - a) Check whether the inputs are references to the same object. This is simple, but using a state_preprocessor breaks the reference, so it would either perform slightly worse or we would need to cache the state_preprocessor output as well.

  - b) Compare the input and the cached input tensors directly. This adds some computational overhead, but it is negligible compared to the time saved.
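Option (b) could look roughly like the following. This is a hypothetical sketch (the `ForwardCache` wrapper is not part of the PR): it stores the last input and output, and reuses the output only when the new input matches element-wise.

```python
import torch

class ForwardCache:
    """Hypothetical wrapper implementing option (b): reuse the last
    output when the new input tensor is element-wise identical."""

    def __init__(self, fn):
        self.fn = fn
        self.last_input = None
        self.last_output = None

    def __call__(self, x):
        if (self.last_input is not None
                and x.shape == self.last_input.shape
                and torch.equal(x, self.last_input)):
            # Cache hit: identical input, skip the forward computation.
            return self.last_output
        # Cache miss: recompute and remember input and output.
        self.last_input = x
        self.last_output = self.fn(x)
        return self.last_output
```

The element-wise comparison (`torch.equal`) is the overhead mentioned above; for large networks it is cheap relative to a full forward pass through the trunk.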

@Toni-SM Toni-SM changed the base branch from main to develop June 13, 2024 00:50
@Toni-SM Toni-SM merged commit 32f25d6 into Toni-SM:develop Jun 15, 2024
1 check passed