# ***Flashbax Test*** 

In [3]:
from stoix.systems.q_learning.dqn_types import Transition
import flashbax as fbx
import jax
import jax.numpy as jnp
import gymnax
from dataclasses import dataclass

In [4]:
@dataclass
class Config:
    """Hyperparameters for the replay-buffer experiments in this notebook.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field. Without annotations the decorator sees no fields at
    all: the generated ``__init__`` accepts no arguments and ``__repr__``
    prints ``Config()``.
    """

    # Number of environment steps per vectorised environment.
    rollout_length: int = 3
    # Total effective size of the replay buffer across all devices and
    # vectorised update steps. This means each device has a buffer of size
    # buffer_size // num_devices, which is further divided by the
    # update_batch_size. This value must be divisible by
    # num_devices * update_batch_size.
    buffer_size: int = 500_000
    # Total effective number of samples to train on. This means each device
    # has a batch size of batch_size / num_devices, which is further divided
    # by the update_batch_size. This value must be divisible by
    # num_devices * update_batch_size.
    batch_size: int = 256


config = Config()

In [5]:
def tree_shape(tree):
    """Return a pytree shaped like ``tree`` with each leaf replaced by its ``.shape``.

    Uses ``jax.tree_util.tree_map`` rather than the deprecated top-level
    ``jax.tree_map`` alias, matching the other ``tree_map`` calls in this file.
    """
    return jax.tree_util.tree_map(lambda x: x.shape, tree)

In [6]:
# Reset CartPole once so a real observation can serve as the template.
env, env_params = gymnax.make("CartPole-v1")
init_x = env.reset(jax.random.PRNGKey(0))[0]
# Give every observation leaf a leading axis of size 1.
init_x = jax.tree_util.tree_map(lambda leaf: jnp.expand_dims(leaf, axis=0), init_x)

# A single, unbatched transition whose structure, shapes and dtypes are used
# to initialise the replay buffer. The leading axis added above is removed
# again for `obs` and `next_obs`.
dummy_transition = Transition(
    obs=jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), init_x),
    action=jnp.zeros(shape=(), dtype=int),
    reward=jnp.zeros(shape=(), dtype=float),
    done=jnp.zeros(shape=(), dtype=bool),
    next_obs=jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), init_x),
    # Each info entry is a vector of length `rollout_length`.
    info={
        "episode_return": jnp.zeros(config.rollout_length, dtype=float),
        "episode_length": jnp.zeros(config.rollout_length, dtype=int),
        "is_terminal_step": jnp.zeros(config.rollout_length, dtype=bool),
    },
)

In [7]:
# Display the dummy transition to inspect the values and dtypes of each field.
dummy_transition

Transition(obs=Array([ 0.04653214, -0.02748411,  0.01330299, -0.02036182], dtype=float32), action=Array(0, dtype=int32), reward=Array(0., dtype=float32), done=Array(False, dtype=bool), next_obs=Array([ 0.04653214, -0.02748411,  0.01330299, -0.02036182], dtype=float32), info={'episode_return': Array([0., 0., 0.], dtype=float32), 'episode_length': Array([0, 0, 0], dtype=int32), 'is_terminal_step': Array([False, False, False], dtype=bool)})

In [8]:
# Display only the per-leaf shapes of the (unbatched) dummy transition.
tree_shape(dummy_transition)

Transition(obs=(4,), action=(), reward=(), done=(), next_obs=(4,), info={'episode_length': (3,), 'episode_return': (3,), 'is_terminal_step': (3,)})

In [46]:
# Number of trajectories written to the buffer by each `add` call.
add_batch_size = 1

# Trajectory replay buffer configured from `config`.
buffer_fn = fbx.make_trajectory_buffer(
    # Batch size of trajectories added to the buffer.
    add_batch_size=add_batch_size,
    # Batch size of trajectories sampled from the buffer.
    sample_batch_size=config.batch_size,
    # Sequence length of trajectories sampled from the buffer.
    sample_sequence_length=config.rollout_length,
    # Period at which we sample trajectories from the buffer.
    period=config.rollout_length,
    # Minimum length across the time axis before we can sample.
    min_length_time_axis=config.batch_size,
    # Maximum length of the buffer along the time axis.
    max_length_time_axis=config.buffer_size,
)

In [47]:
# Allocate the buffer state using the dummy transition as the template for
# the structure and dtype of the stored experience.
buffer_states = buffer_fn.init(dummy_transition)
# Show the allocated shapes: in the output below every experience leaf has
# leading axes (1, 500000), i.e. (add_batch_size, buffer_size).
tree_shape(buffer_states)

TrajectoryBufferState(experience=Transition(obs=(1, 500000, 4), action=(1, 500000), reward=(1, 500000), done=(1, 500000), next_obs=(1, 500000, 4), info={'episode_length': (1, 500000, 3), 'episode_return': (1, 500000, 3), 'is_terminal_step': (1, 500000, 3)}), current_index=(), is_full=())

## **Training Scenario**

The batch dimension of the training batch == add_batch_size, and each trajectory in it has shape (rollout_length, *).

In [48]:
def broadcast_fn(x):
    """Tile a leaf to shape ``(add_batch_size, rollout_length, *x.shape)``."""
    return jnp.broadcast_to(x, (add_batch_size, config.rollout_length, *x.shape))


# Fake training batch: one trajectory batch built from the dummy transition.
# jax.tree_util.tree_map is used instead of the deprecated jax.tree_map alias.
fake_batch_sequence = jax.tree_util.tree_map(broadcast_fn, dummy_transition)

In [49]:
# Check that every leaf now has leading axes (add_batch_size, rollout_length).
tree_shape(fake_batch_sequence)

Transition(obs=(1, 3, 4), action=(1, 3), reward=(1, 3), done=(1, 3), next_obs=(1, 3, 4), info={'episode_length': (1, 3, 3), 'episode_return': (1, 3, 3), 'is_terminal_step': (1, 3, 3)})

In [51]:
# Add the fake training batch to the buffer; `add` returns the updated state,
# which replaces the previous one.
buffer_states = buffer_fn.add(buffer_states, fake_batch_sequence)

In [52]:
# Sample a batch of trajectories and show the shapes of the sampled leaves.
# NOTE(review): only rollout_length (3) timesteps have been added so far,
# fewer than min_length_time_axis (256) — presumably the buffer is not yet
# considered sampleable at this point; confirm against the flashbax docs.
sample = buffer_fn.sample(buffer_states, jax.random.key(0))
tree_shape(sample)

TrajectoryBufferSample(experience=Transition(obs=(256, 3, 4), action=(256, 3), reward=(256, 3), done=(256, 3), next_obs=(256, 3, 4), info={'episode_length': (256, 3, 3), 'episode_return': (256, 3, 3), 'is_terminal_step': (256, 3, 3)}))

## **Warmup Scenario**

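The batch dimension of the warmup batch != add_batch_size, and transitions have shape (1, *) instead of (rollout_length, *). In this notebook the fake warmup batch has shape (n_envs, n_warmup_steps, *) = (1024, 33, *); it is reshaped into sequences of length rollout_length, which are then added to the buffer one at a time.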

In [72]:
def buffer_seq_add(buffer_states, transition_batch: Transition, config: Config = config):
    """Add a warmup batch to the replay buffer one sequence at a time.

    Every leaf of ``transition_batch`` is expected to have leading axes
    ``(n_envs, n_warmup_steps)``. Each leaf is reshaped to
    ``(-1, rollout_length, ...)`` and the resulting sequences are written to
    the module-level ``buffer_fn`` sequentially with ``jax.lax.scan``, each
    one given a leading batch axis of size 1.

    NOTE(review): adding with a batch axis of size 1 presumably requires
    ``buffer_fn`` to have been created with ``add_batch_size == 1`` — confirm.

    Args:
        buffer_states: Current buffer state, as returned by ``buffer_fn.init``
            or ``buffer_fn.add``.
        transition_batch: Warmup transitions with leading axes
            ``(n_envs, n_warmup_steps)``.
        config: Supplies ``rollout_length``, the length of each added sequence.

    Returns:
        The buffer state after all sequences have been added.

    Raises:
        ValueError: If the time axis of a leaf is not a multiple of
            ``config.rollout_length``.
    """
    rollout_length = config.rollout_length

    def _add_single_transition(buffer_states, transition):
        # scan strips the leading axis; restore a batch axis of size 1.
        transition = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, axis=0), transition)
        buffer_states = buffer_fn.add(buffer_states, transition)
        return buffer_states, None

    def _reshape_fn(x: jnp.ndarray):
        """
        Reshape a batch of transitions with shape (n_envs, n_warmup_steps, *x)
        to (-1, rollout_length, *x) to be added to the replay buffer.
        """
        # Shapes are static, so this check runs at trace time. Without it, a
        # non-divisible time axis can still reshape successfully (when the
        # total size divides) and silently produce sequences that straddle
        # two environments.
        if x.ndim >= 2 and x.shape[1] % rollout_length != 0:
            raise ValueError(
                f"n_warmup_steps ({x.shape[1]}) must be divisible by "
                f"rollout_length ({rollout_length})."
            )
        return x.reshape(-1, rollout_length, *x.shape[2:])

    reshaped_batch = jax.tree_util.tree_map(_reshape_fn, transition_batch)
    # The buffer state is the scan carry; each step adds one sequence.
    buffer_states, _ = jax.lax.scan(
        _add_single_transition,
        buffer_states,
        reshaped_batch,
    )

    return buffer_states

def w_broadcast_fn(x):
    """Tile a leaf to shape ``(n_envs, n_warmup_steps, *x.shape)`` = ``(1024, 33, *x.shape)``."""
    return jnp.broadcast_to(x, (1024, 33, *x.shape))


# Fake warmup batch built from the dummy transition.
# jax.tree_util.tree_map is used instead of the deprecated jax.tree_map alias.
w_fake_batch_sequence = jax.tree_util.tree_map(w_broadcast_fn, dummy_transition)

# Write the warmup batch into the buffer sequence by sequence.
buffer_states = buffer_seq_add(buffer_states, w_fake_batch_sequence, config)

# The buffer's allocated shapes should be unchanged by the additions.
tree_shape(buffer_states)

TrajectoryBufferState(experience=Transition(obs=(1, 500000, 4), action=(1, 500000), reward=(1, 500000), done=(1, 500000), next_obs=(1, 500000, 4), info={'episode_length': (1, 500000, 3), 'episode_return': (1, 500000, 3), 'is_terminal_step': (1, 500000, 3)}), current_index=(), is_full=())

In [73]:
# Sample again now that the warmup data (1024 * 33 timesteps) has been added,
# using the same PRNG key as the earlier sample, and show the sampled shapes.
sample = buffer_fn.sample(buffer_states, jax.random.key(0))
tree_shape(sample)

TrajectoryBufferSample(experience=Transition(obs=(256, 3, 4), action=(256, 3), reward=(256, 3), done=(256, 3), next_obs=(256, 3, 4), info={'episode_length': (256, 3, 3), 'episode_return': (256, 3, 3), 'is_terminal_step': (256, 3, 3)}))