In [1]:
from typing import Callable, Generator, Optional, Tuple, Union, NamedTuple

import numpy as np
import torch as th
import jax.numpy as jnp
from gymnasium import spaces
# TODO : see later how to enable DictRolloutBuffer
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env


# TODO : add type aliases for the NamedTuple
class LSTMStates(NamedTuple):
    pi: Tuple
    vf: Tuple

# TODO : Replaced th.Tensor with jnp.ndarray but might not be true (some as still th Tensors because used in other sb3 fns)
# Added lstm states but also dones because they are used in actor and critic
class RecurrentRolloutBufferSamples(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    old_values: jnp.ndarray
    old_log_prob: jnp.ndarray
    advantages: jnp.ndarray
    returns: jnp.ndarray
    dones: jnp.ndarray
    lstm_states: LSTMStates

# Add a recurrent buffer that also takes care of the lstm states and dones flags
class RecurrentRolloutBuffer(RolloutBuffer):
    """
    Rollout buffer that also stores the LSTM cell and hidden states.

    :param buffer_size: Max number of element in the buffer
    :param observation_space: Observation space
    :param action_space: Action space
    :param hidden_state_shape: Shape of the buffer that will collect lstm states
        (n_steps, lstm.num_layers, n_envs, lstm.hidden_size)
    :param device: PyTorch device
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
        Equivalent to classic advantage when set to 1.
    :param gamma: Discount factor
    :param n_envs: Number of parallel environments
    """

    def __init__(       
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        # renamed this because I found hidden_state_shape confusing
        lstm_state_buffer_shape: Tuple[int, int, int],
        device: Union[th.device, str] = "auto",
        gae_lambda: float = 1,
        gamma: float = 0.99,
        n_envs: int = 1,
    ):  
        self.hidden_state_shape = lstm_state_buffer_shape
        self.seq_start_indices, self.seq_end_indices = None, None
        super().__init__(buffer_size, observation_space, action_space, device, gae_lambda, gamma, n_envs)

    # TODO : remove dones because already episode starts in the buffer
    def reset(self):
        super().reset()
        # also add the dones and all lstm states
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.hidden_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_pi = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.hidden_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)
        self.cell_states_vf = np.zeros(self.hidden_state_shape, dtype=np.float32)

    # TODO : remove dones because already episode starts in the buffer
    def add(self, *args, dones, lstm_states, **kwargs) -> None:
        """
        :param hidden_states: LSTM cell and hidden state
        """
        self.dones[self.pos] = np.array(dones)
        self.hidden_states_pi[self.pos] = np.array(lstm_states.pi[0])
        self.cell_states_pi[self.pos] = np.array(lstm_states.pi[1])
        self.hidden_states_vf[self.pos] = np.array(lstm_states.vf[0])
        self.cell_states_vf[self.pos] = np.array(lstm_states.vf[1])

        super().add(*args, **kwargs)

    def get(self, batch_size: Optional[int] = None) -> Generator[RecurrentRolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling from it"

        # Prepare the data
        if not self.generator_ready:
            # hidden_state_shape = (self.n_steps, lstm.num_layers, self.n_envs, lstm.hidden_size)
            # swap first to (self.n_steps, self.n_envs, lstm.num_layers, lstm.hidden_size)
            for tensor in ["hidden_states_pi", "cell_states_pi", "hidden_states_vf", "cell_states_vf"]:
                self.__dict__[tensor] = self.__dict__[tensor].swapaxes(1, 2)

            # flatten but keep the sequence order
            # 1. (n_steps, n_envs, *tensor_shape) -> (n_envs, n_steps, *tensor_shape)
            # 2. (n_envs, n_steps, *tensor_shape) -> (n_envs * n_steps, *tensor_shape)
            for tensor in [
                "observations",
                "actions",
                "values",
                "log_probs",
                "advantages",
                "returns",
                "dones",
                "hidden_states_pi",
                "cell_states_pi",
                "hidden_states_vf",
                "cell_states_vf",
                "episode_starts",
            ]:
                self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.buffer_size * self.n_envs

        # TODO : See how to effectively use the indices to conserve temporal order in the batch data during updates
        # TODO : I think the easisest way is to ensure the n_steps is a multiple of batch_size
        # TODO : But still need to be fixed at the moment (I just made sure the returned shape was right)
        indices = np.arange(self.buffer_size * self.n_envs)

        start_idx = 0
        while start_idx < self.buffer_size * self.n_envs:
            batch_inds = indices[start_idx : start_idx + batch_size]
            yield self._get_samples(batch_inds)
            start_idx += batch_size

    # return the lstm states as an LSTMStates tuple
    def _get_samples(
        self,
        batch_inds: np.ndarray,
        env: Optional[VecNormalize] = None,
    ) -> RecurrentRolloutBufferSamples:
        
        lstm_states_pi = (
            self.hidden_states_pi[batch_inds],
            self.cell_states_pi[batch_inds]
        )

        lstm_states_vf = (
            self.hidden_states_vf[batch_inds],
            self.cell_states_vf[batch_inds]
        )

        data = (
            self.observations[batch_inds],
            self.actions[batch_inds],
            self.values[batch_inds].flatten(),
            self.log_probs[batch_inds].flatten(),
            self.advantages[batch_inds].flatten(),
            self.returns[batch_inds].flatten(),
            self.dones[batch_inds],
            LSTMStates(pi=lstm_states_pi, vf=lstm_states_vf)
        )
        return RecurrentRolloutBufferSamples(*tuple(map(self.to_torch, data)))


In [12]:
n_envs = 8
n_steps = 128
batch_size = 32
buffer_size = n_envs * n_steps
gamma = 0.99
n_epochs = 4
gae_lambda = 0.95
hidden_size = 64
lstm_state_buffer_shape = (n_steps, n_envs, 64)

env_id = "CartPole-v1"
vec_env = make_vec_env(env_id, n_envs=n_envs)

In [40]:
rollout_buffer = RecurrentRolloutBuffer(
    n_steps,
    vec_env.observation_space,
    vec_env.action_space,
    gamma=gamma,
    gae_lambda=gae_lambda,
    n_envs=n_envs,
    lstm_state_buffer_shape=lstm_state_buffer_shape,
    device="cpu",
)

In [41]:
for i in range(n_steps):

    lstm_states = (
        np.full((n_envs, hidden_size), i, dtype=np.float32),
        np.full((n_envs, hidden_size), i, dtype=np.float32),
    )

    lstm_states = LSTMStates(pi=lstm_states, vf=lstm_states)

    act = np.array([i + 0.1 * idx for idx in range(n_envs)]).reshape(-1, 1)

    rollout_buffer.add(
        obs=np.full((n_envs, 4), i, dtype=np.float32),
        action=act,
        # action=np.full((n_envs, 1), i, dtype=np.float32),
        reward=np.full((n_envs, ), i, dtype=np.float32),
        episode_start=np.zeros((n_envs,), dtype=np.float32),
        value=th.ones((n_envs, 1), dtype=th.float32),
        log_prob=th.ones((n_envs, ), dtype=th.float32),
        dones=np.zeros((n_envs,), dtype=np.float32),
        lstm_states=lstm_states,
    )

In [45]:
print(f"{rollout_buffer.actions.shape = }")
print(f"{rollout_buffer.actions[0] = }")
print(f"{rollout_buffer.actions[1] = }")
print(f"{rollout_buffer.actions[127] = }")

rollout_buffer.actions.shape = (128, 8, 1)
rollout_buffer.actions[0] = array([[0. ],
       [0.1],
       [0.2],
       [0.3],
       [0.4],
       [0.5],
       [0.6],
       [0.7]], dtype=float32)
rollout_buffer.actions[1] = array([[1. ],
       [1.1],
       [1.2],
       [1.3],
       [1.4],
       [1.5],
       [1.6],
       [1.7]], dtype=float32)
rollout_buffer.actions[127] = array([[127. ],
       [127.1],
       [127.2],
       [127.3],
       [127.4],
       [127.5],
       [127.6],
       [127.7]], dtype=float32)


Ok in fact at each iteration I need to differentiate between the current idx (i) and the idx of the environment (n_envs). How can I do that ? 

- pass environment idx as the first digit of the number and then 


In [46]:
# train for n_epochs epochs
# for j in range(n_epochs):
count = 0
for rollout_data in rollout_buffer.get(batch_size):  # type: ignore[attr-defined]
    print(f"{rollout_data.actions = }")
    count += 1

print(f"{count = }")

rollout_data.actions = tensor([[ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.],
        [11.],
        [12.],
        [13.],
        [14.],
        [15.],
        [16.],
        [17.],
        [18.],
        [19.],
        [20.],
        [21.],
        [22.],
        [23.],
        [24.],
        [25.],
        [26.],
        [27.],
        [28.],
        [29.],
        [30.],
        [31.]])
rollout_data.actions = tensor([[32.],
        [33.],
        [34.],
        [35.],
        [36.],
        [37.],
        [38.],
        [39.],
        [40.],
        [41.],
        [42.],
        [43.],
        [44.],
        [45.],
        [46.],
        [47.],
        [48.],
        [49.],
        [50.],
        [51.],
        [52.],
        [53.],
        [54.],
        [55.],
        [56.],
        [57.],
        [58.],
        [59.],
        [60.],
        [61.],
        [62.],
        

- Seems like the batches are quite ok with the custom actions I give 
- Now need to check if it is also the case for the lstm components

In [49]:
# train for n_epochs epochs
# for j in range(n_epochs):
count = 0
for rollout_data in rollout_buffer.get(batch_size):  # type: ignore[attr-defined]
    count += 1
    print(f"{rollout_data.lstm_states[0][0].shape = }")

print(f"{count = }")

# lstm_states_pi = (
#     rollout_data.lstm_states[0][0].numpy().reshape(batch_size, hidden_state_size),
#     rollout_data.lstm_states[0][1].numpy().reshape(batch_size, hidden_state_size)
# )

# lstm_states_vf = (
#     rollout_data.lstm_states[1][0].numpy().reshape(batch_size, hidden_state_size),
#     rollout_data.lstm_states[1][1].numpy().reshape(batch_size, hidden_state_size)
# )

rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8])
rollout_data.lstm_states[0][0].shape = torch.Size([32, 8