# ***Flashbax Test*** 

In [1]:
from stoix.systems.q_learning.dqn_types import Transition
import flashbax as fbx
import jax
import jax.numpy as jnp
import gymnax
from dataclasses import dataclass

In [2]:
@dataclass
class Config:
    """Sizes shared by the replay-buffer experiments in this notebook.

    Every attribute carries a type annotation so that `@dataclass` registers
    it as a field: the values then appear in `repr`, take part in `==`, and
    can be overridden per instance, e.g. `Config(rollout_length=5)`.
    `Config()` keeps the same defaults as before.
    """

    # Number of environment steps per vectorised environment.
    rollout_length: int = 3
    # Total effective size of the replay buffer across all devices and
    # vectorised update steps. Each device has a buffer of size
    # buffer_size // num_devices, which is further divided by the
    # update_batch_size. Must be divisible by num_devices * update_batch_size.
    buffer_size: int = 256
    # Total effective number of samples to train on. Each device has a batch
    # size of batch_size / num_devices, which is further divided by the
    # update_batch_size. Must be divisible by num_devices * update_batch_size.
    batch_size: int = 16


def tree_shape(tree):
    """Return a pytree shaped like `tree` with every leaf replaced by its `.shape`.

    Uses `jax.tree_util.tree_map` rather than the deprecated top-level
    `jax.tree_map` alias, which is not available in newer JAX releases.
    """
    return jax.tree_util.tree_map(lambda leaf: leaf.shape, tree)


# Single shared configuration instance read by the cells below.
config = Config()

In [4]:
# Create the gymnax CartPole environment; its handles are only needed by the
# commented-out reset below.
env, env_params = gymnax.make("CartPole-v1")
# init_x = env.reset(jax.random.PRNGKey(0))[0]
# Placeholder observation of shape (4,), given a leading singleton axis.
init_x = jnp.arange(4)
init_x = jax.tree_util.tree_map(lambda leaf: jnp.expand_dims(leaf, axis=0), init_x)
# Template transition: the singleton axis is dropped again for the
# observations, and each `info` entry is a vector of length rollout_length.
dummy_transition = Transition(
    obs=jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), init_x),
    action=jnp.zeros((), dtype=int),
    reward=jnp.zeros((), dtype=float),
    done=jnp.zeros((), dtype=bool),
    next_obs=jax.tree_util.tree_map(lambda leaf: jnp.squeeze(leaf, axis=0), init_x),
    info={
        "episode_return": jnp.zeros((config.rollout_length,), dtype=float),
        "episode_length": jnp.zeros((config.rollout_length,), dtype=int),
        "is_terminal_step": jnp.zeros((config.rollout_length,), dtype=bool),
    },
)
dummy_transition

Transition(obs=Array([0, 1, 2, 3], dtype=int32), action=Array(0, dtype=int32), reward=Array(0., dtype=float32), done=Array(False, dtype=bool), next_obs=Array([0, 1, 2, 3], dtype=int32), info={'episode_return': Array([0., 0., 0.], dtype=float32), 'episode_length': Array([0, 0, 0], dtype=int32), 'is_terminal_step': Array([False, False, False], dtype=bool)})

In [5]:
# Display the shape of every leaf of the single dummy transition.
tree_shape(dummy_transition)

Transition(obs=(4,), action=(), reward=(), done=(), next_obs=(4,), info={'episode_length': (3,), 'episode_return': (3,), 'is_terminal_step': (3,)})

In [6]:
# NOTE(review): `add_batch_size` is assigned here but is not passed to the
# constructor below, which uses `config.rollout_length` instead (both are 3).
# Confirm which one is intended.
add_batch_size = 3
buffer_fn = fbx.make_trajectory_buffer(
    # Batch dimensions: trajectories added per call and sampled per call.
    add_batch_size=config.rollout_length,
    sample_batch_size=config.batch_size,
    # Time axis: maximum length, and minimum length before sampling is possible.
    max_length_time_axis=config.buffer_size,
    min_length_time_axis=config.batch_size,
    # Sequence length of each sampled trajectory, and the sampling period.
    sample_sequence_length=config.rollout_length,
    period=config.rollout_length,
)

In [7]:
# Initialise the buffer state from the dummy transition.
buffer_states = buffer_fn.init(dummy_transition)
# Display the shape of every leaf of the freshly initialised state.
tree_shape(buffer_states)

TrajectoryBufferState(experience=Transition(obs=(3, 256, 4), action=(3, 256), reward=(3, 256), done=(3, 256), next_obs=(3, 256, 4), info={'episode_length': (3, 256, 3), 'episode_return': (3, 256, 3), 'is_terminal_step': (3, 256, 3)}), current_index=(), is_full=())

## **Training Scenario**

Batch dimension of the training batch == add_batch_size and transitions have shape (rollout_length, *)

In [8]:
def broadcast_fn(x):
    """Broadcast `x` to shape `(config.rollout_length, *x.shape)`."""
    return jnp.broadcast_to(x, (config.rollout_length, *x.shape))


# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
fake_batch_sequence = jax.tree_util.tree_map(broadcast_fn, dummy_transition)

In [9]:
# Display leaf shapes after broadcasting: each leaf gains a leading axis of length rollout_length.
tree_shape(fake_batch_sequence)

Transition(obs=(3, 4), action=(3,), reward=(3,), done=(3,), next_obs=(3, 4), info={'episode_length': (3, 3), 'episode_return': (3, 3), 'is_terminal_step': (3, 3)})

In [10]:
# Insert a singleton axis at position 1 in every leaf, e.g. obs (3, 4) -> (3, 1, 4).
# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
expanded_batched_sequence = jax.tree_util.tree_map(
    lambda x: jnp.expand_dims(x, axis=1), fake_batch_sequence
)
buffer_states = buffer_fn.add(buffer_states, expanded_batched_sequence)

In [11]:
# Draw one batch from the buffer with a fixed PRNG key, then display its leaf shapes.
sample = buffer_fn.sample(buffer_states, jax.random.key(0))
tree_shape(sample)

TrajectoryBufferSample(experience=Transition(obs=(16, 3, 4), action=(16, 3), reward=(16, 3), done=(16, 3), next_obs=(16, 3, 4), info={'episode_length': (16, 3, 3), 'episode_return': (16, 3, 3), 'is_terminal_step': (16, 3, 3)}))

In [12]:
# Show the observations of the first sampled sequence.
sample.experience.obs[0]

Array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 1, 2, 3]], dtype=int32)

## **Warmup Scenario**

Batch dimension of the warmup batch != add_batch_size and transitions have shape (1, *) instead of (rollout_length, *)

In [20]:
def buffer_seq_add(buffer_states, transition_batch: Transition, config: Config = config):
    """Sequentially add a warmup Transition batch to the replay buffer.

    Every leaf of `transition_batch` is reshaped to
    `(-1, config.rollout_length, *leaf.shape[2:])`. The chunks along the new
    leading axis are then added one after another with `jax.lax.scan`, so
    that each add matches the `add_batch_size` requirement of the
    TrajectoryBuffer.

    Assumes n_warmup_steps % rollout_length == 0.

    Args:
        buffer_states: Buffer state to add to (as produced by the module-level
            `buffer_fn`, which this function uses for the adds).
        transition_batch: Transition pytree whose leaves have shape
            (n_envs, n_warmup_steps, *x).
        config: Supplies `rollout_length`, the chunk size used when reshaping.

    Returns:
        The buffer state after every chunk has been added.
    """

    def _add_single_transition(buffer_states, transition):
        """Scan body: add one chunk with a time axis of length 1 (seq_len=1)."""
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        transition = jax.tree_util.tree_map(
            lambda x: jnp.expand_dims(x, axis=1), transition
        )
        buffer_states = buffer_fn.add(buffer_states, transition)
        # The scan only threads the carry (buffer state); no per-step output.
        return buffer_states, None

    def _reshape_fn(x: jnp.ndarray):
        """
        Reshape a batch of transitions with shape (n_envs, n_warmup_steps, *x)
        to (-1, rollout_length, *x) to be added to the replay buffer.
        """
        return x.reshape(-1, config.rollout_length, *x.shape[2:])

    reshaped_batch = jax.tree_util.tree_map(_reshape_fn, transition_batch)
    # Debug aid: show the leaf shapes that the scan iterates over.
    print(tree_shape(reshaped_batch))
    # Scan over the leading axis of `reshaped_batch`, carrying the buffer state.
    buffer_states, _ = jax.lax.scan(
        _add_single_transition,
        buffer_states,
        reshaped_batch,
    )

    return buffer_states

# Fake warmup batch whose leaves have shape (n_envs, n_warmup_steps, *x_shape).
# Breaks for n_warmup_steps such that n_warmup_steps % rollout_length != 0.
def w_broadcast_fn(x):
    """Broadcast `x` to shape `(128, 24, *x.shape)`."""
    return jnp.broadcast_to(x, (128, 24, *x.shape))


# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias.
w_fake_batch_sequence = jax.tree_util.tree_map(w_broadcast_fn, dummy_transition)

buffer_states = buffer_seq_add(buffer_states, w_fake_batch_sequence, config)

tree_shape(buffer_states)

Transition(obs=(1024, 3, 4), action=(1024, 3), reward=(1024, 3), done=(1024, 3), next_obs=(1024, 3, 4), info={'episode_length': (1024, 3, 3), 'episode_return': (1024, 3, 3), 'is_terminal_step': (1024, 3, 3)})


TrajectoryBufferState(experience=Transition(obs=(3, 256, 4), action=(3, 256), reward=(3, 256), done=(3, 256), next_obs=(3, 256, 4), info={'episode_length': (3, 256, 3), 'episode_return': (3, 256, 3), 'is_terminal_step': (3, 256, 3)}), current_index=(), is_full=())

In [17]:
# Display the leaf shapes of the fake warmup batch.
tree_shape(w_fake_batch_sequence)

Transition(obs=(128, 24, 4), action=(128, 24), reward=(128, 24), done=(128, 24), next_obs=(128, 24, 4), info={'episode_length': (128, 24, 3), 'episode_return': (128, 24, 3), 'is_terminal_step': (128, 24, 3)})

In [15]:
# Draw one batch from the buffer with a fixed PRNG key, then display its leaf shapes.
sample = buffer_fn.sample(buffer_states, jax.random.key(0))
tree_shape(sample)

TrajectoryBufferSample(experience=Transition(obs=(16, 3, 4), action=(16, 3), reward=(16, 3), done=(16, 3), next_obs=(16, 3, 4), info={'episode_length': (16, 3, 3), 'episode_return': (16, 3, 3), 'is_terminal_step': (16, 3, 3)}))