# ***Flashbax Test*** 

In [25]:
from stoix.systems.q_learning.dqn_types import Transition
import flashbax as fbx
import jax
import jax.numpy as jnp
import gymnax
from dataclasses import dataclass

In [43]:
@dataclass
class Config:
    rollout_length = 3  # Number of environment steps per vectorised environment.
    buffer_size = 500_000  # Total effective size of the replay buffer across all devices and vectorised update steps. This means each device has a buffer of size buffer_size//num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.
    batch_size = 256  # Total effective number of samples to train on. This means each device has a batch size of batch_size/num_devices which is further divided by the update_batch_size. This value must be divisible by num_devices*update_batch_size.


config = Config()

In [44]:
def tree_shape(tree):
    return jax.tree_map(lambda x: x.shape, tree)

In [62]:
env, env_params = gymnax.make("CartPole-v1")
init_x = env.reset(jax.random.PRNGKey(0))[0]
dummy_transition = Transition(
    obs=jnp.broadcast_to(init_x, shape=(config.rollout_length, *init_x.squeeze().shape)),
    action=jnp.zeros((config.rollout_length), dtype=int),
    reward=jnp.zeros((config.rollout_length), dtype=float),
    done=jnp.zeros((config.rollout_length), dtype=bool),
    next_obs=jnp.broadcast_to(init_x, shape=(config.rollout_length, *init_x.squeeze().shape)),
    info={
        "episode_return": jnp.zeros((config.rollout_length), dtype=float),
        "episode_length": jnp.zeros((config.rollout_length), dtype=int),
        "is_terminal_step": jnp.zeros((config.rollout_length), dtype=bool),
    },
)


In [63]:
dummy_transition

Transition(obs=Array([[ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182]],      dtype=float32), action=Array([0, 0, 0], dtype=int32), reward=Array([0., 0., 0.], dtype=float32), done=Array([False, False, False], dtype=bool), next_obs=Array([[ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182]],      dtype=float32), info={'episode_return': Array([0., 0., 0.], dtype=float32), 'episode_length': Array([0, 0, 0], dtype=int32), 'is_terminal_step': Array([False, False, False], dtype=bool)})

In [64]:
tree_shape(dummy_transition)

Transition(obs=(3, 4), action=(3,), reward=(3,), done=(3,), next_obs=(3, 4), info={'episode_length': (3,), 'episode_return': (3,), 'is_terminal_step': (3,)})

In [65]:
buffer_fn = fbx.make_trajectory_buffer(
    max_length_time_axis=config.buffer_size,  # Maximum length of the buffer along the time axis.
    min_length_time_axis=config.batch_size,  # Minimum length across the time axis before we can sample.
    sample_batch_size=config.batch_size,  # Batch size of trajectories sampled from the buffer.
    add_batch_size=config.rollout_length,  # Batch size of trajectories added to the buffer.
    sample_sequence_length=config.rollout_length,  # Sequence length of trajectories sampled from the buffer.
    period=config.rollout_length - 1,  # Period at which we sample trajectories from the buffer.
)
add_sequence_length=config.rollout_length,  # Sequence length of trajectories added to the buffer.

In [66]:
buffer_states = buffer_fn.init(dummy_transition)
tree_shape(buffer_states)

TrajectoryBufferState(experience=Transition(obs=(3, 500000, 3, 4), action=(3, 500000, 3), reward=(3, 500000, 3), done=(3, 500000, 3), next_obs=(3, 500000, 3, 4), info={'episode_length': (3, 500000, 3), 'episode_return': (3, 500000, 3), 'is_terminal_step': (3, 500000, 3)}), current_index=(), is_full=())

In [67]:
dummy_transition

Transition(obs=Array([[ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182]],      dtype=float32), action=Array([0, 0, 0], dtype=int32), reward=Array([0., 0., 0.], dtype=float32), done=Array([False, False, False], dtype=bool), next_obs=Array([[ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182],
       [ 0.04653214, -0.02748411,  0.01330299, -0.02036182]],      dtype=float32), info={'episode_return': Array([0., 0., 0.], dtype=float32), 'episode_length': Array([0, 0, 0], dtype=int32), 'is_terminal_step': Array([False, False, False], dtype=bool)})

In [68]:
buffer_states = buffer_fn.add(buffer_states, dummy_transition)

AssertionError: [Chex] Assertion assert_tree_shape_prefix failed: Tree leaf 'action' has a shape of length 1 (shape=(3,)) which is smaller than the expected prefix of length 2 (prefix=(3, 4)).
Tree leaf 'reward' has a shape of length 1 (shape=(3,)) which is smaller than the expected prefix of length 2 (prefix=(3, 4)).
Tree leaf 'done' has a shape of length 1 (shape=(3,)) which is smaller than the expected prefix of length 2 (prefix=(3, 4)).
Tree leaf 'info/episode_length' has a shape of length 1 (shape=(3,)) which is smaller than the expected prefix of length 2 (prefix=(3, 4)).
Tree leaf 'info/episode_return' has a shape of length 1 (shape=(3,)) which is smaller than the expected prefix of length 2 (prefix=(3, 4)).
Tree leaf 'info/is_terminal_step' has a shape of length 1 (shape=(3,)) which is smaller than the expected prefix of length 2 (prefix=(3, 4)).

In [35]:
sample = buffer_fn.sample(buffer_states, jax.random.key(0))
tree_shape(sample)

TrajectoryBufferSample(experience=Transition(obs=(85, 1, 3, 4), action=(85, 1, 3), reward=(85, 1, 3), done=(85, 1, 3), next_obs=(85, 1, 3, 4), info={'episode_length': (85, 1, 3), 'episode_return': (85, 1, 3), 'is_terminal_step': (85, 1, 3)}))