Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RecurrentPPO: 9x speedup - whole sequence batching #118

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from

Conversation

b-vm
Copy link

@b-vm b-vm commented Nov 28, 2022

Description

Moving from 2d batches to 3d batches of whole sequences leads to a 5-9 times speedup in terms of fps while keeping results similar. Proof.

Context

  • I have raised an issue to propose this change (required)

Types of changes

Its currently implemented as an additional feature but would probably be more optimal to replace the original.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • The functionality/performance matches that of the source (required for new training algorithms or training-related features).
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have included an example of using the feature (required for new features).
  • I have included baseline results (required for new training algorithms or training-related features).
  • I have updated the documentation accordingly.
  • I have updated the changelog accordingly (required).
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

@b-vm
Copy link
Author

b-vm commented Dec 16, 2022

@araffin have you been able to take a look at this yet? I am very curious what you think about it.

@araffin
Copy link
Member

araffin commented Dec 16, 2022

have you been able to take a look at this yet? I am very curious what you think about it.

no, not yet, still on my stack... and going on holidays soon, so, I'll probably take a look next week or in january.

@b-vm
Copy link
Author

b-vm commented Dec 29, 2022

Cool. Let me know if you need any help running experiments/coding

@araffin araffin mentioned this pull request Feb 18, 2023
@b-vm b-vm changed the title RecurrentPPO: Whole sequence batching RecurrentPPO: 9x speedup - whole sequence batching Mar 1, 2023
@araffin araffin self-requested a review March 7, 2023 23:00
@araffin
Copy link
Member

araffin commented Apr 3, 2023

Hello,
I tried but couldn't test the PR, I got an error (before my changes) both with Pendulum and BipedalWalker:

Traceback (most recent call last):
  File "sb3_contrib/whole_sequence_speed_test.py", line 167, in <module>
    model.learn(2e5, tb_log_name=f"PendulumNoVel-v1_whole_sequences_batch_size{batch_size}")
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 505, in learn
    self.train()
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 361, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
  File "sb3_contrib/sb3_contrib/common/recurrent/policies.py", line 372, in evaluate_actions_whole_sequence
    latent_pi, _ = self.lstm_actor(features)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 810, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 730, in check_forward_args
    self.check_input(input, batch_sizes)
  File "mambaforge/lib/python3.10/site-packages/torch/nn/modules/rnn.py", line 218, in check_input
    raise RuntimeError(
RuntimeError: input.size(-1) must be equal to input_size. Expected 3, got 6

@b-vm
Copy link
Author

b-vm commented Apr 12, 2023

My bad. Bug is fixed now!

@araffin
Copy link
Member

araffin commented Apr 27, 2023

I had to set drop_last=False sometimes, otherwise I was getting error due to the fact nothing was sampled:
UnboundLocalError: local variable 'loss' referenced before assignment

To reproduce:

python -m rl_zoo3.train --algo ppo_lstm --env PendulumNoVel-v1 -params whole_sequences:True use_sde:False
python -m rl_zoo3.train --algo ppo_lstm --env CartPoleNoVel-v1 -params whole_sequences:True

On CartPole, I have another error:

Traceback (most recent call last):
  File "torchy-zoo/train.py", line 4, in <module>
    train()
  File "torchy-zoo/rl_zoo3/train.py", line 267, in train
    exp_manager.learn(model)
  File "torchy-zoo/rl_zoo3/exp_manager.py", line 236, in learn
    model.learn(self.n_timesteps, **kwargs)
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 521, in learn
    self.train()
  File "sb3_contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 377, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
  File "sb3_contrib/sb3_contrib/common/recurrent/policies.py", line 387, in evaluate_actions_whole_sequence
    log_prob = distribution.distribution.log_prob(actions).sum(dim=-1)
  File "mambaforge/lib/python3.10/site-packages/torch/distributions/categorical.py", line 123, in log_prob
    self._validate_sample(value)
  File "mambaforge/lib/python3.10/site-packages/torch/distributions/distribution.py", line 288, in _validate_sample
    raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
ValueError: Value is not broadcastable with batch_shape+event_shape: torch.Size([32, 15, 1]) vs torch.Size([32, 15]).

Also, SDE seems not supported (that's ok, but need to be checked at runtime).

Finally, I experienced some NaN issue from time to time when drop_last=False (I fixed that by deactivating advantage normalization) :

ValueError: Expected parameter loc (Tensor of shape (4, 1)) of distribution Normal(loc: torch.Size([4, 1]), scale: torch.Size([4, 1])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan],
        [nan],
        [nan],
        [nan]])

@araffin
Copy link
Member

araffin commented Apr 27, 2023

Also an error when using CNN:

python train.py --algo ppo_lstm --env CarRacing-v2 -P --n-eval-envs 5 --eval-episodes 20 -params batch_size:8 whole_sequences:True
    self.train()
  File "/home/antonin/Documents/rl/sb3-contrib/sb3_contrib/ppo_recurrent/ppo_recurrent.py", line 377, in train
    values, log_prob, entropy = self.policy.evaluate_actions_whole_sequence(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/sb3-contrib/sb3_contrib/common/recurrent/policies.py", line 371, in evaluate_actions_whole_sequence
    features = self.extract_features(obs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/policies.py", line 640, in extract_features
    return super().extract_features(obs, self.features_extractor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/policies.py", line 131, in extract_features
    return features_extractor(preprocessed_obs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/Documents/rl/stable-baselines3/stable_baselines3/common/torch_layers.py", line 106, in forward
    return self.linear(self.cnn(observations))
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
            ^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/antonin/miniconda3/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [281, 8, 2, 64, 64]

@araffin
Copy link
Member

araffin commented Oct 6, 2023

On CartPole, I have another error:

The error for CartPole seems to be still there...

@b-vm
Copy link
Author

b-vm commented Oct 15, 2023

Yes, it has only been implemented for Box action spaces so that might be it.

I have not much time to work on this anymore. So feel free to do it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants