# <h1><center>Q-learning Tutorial</center></h1>

This notebook provides an introductory tutorial to Q-learning. Specifically, we will implement Q-learning using a Q-table in JAX and use it to steer a simplified brittle star robot towards a random target. The [brittle star robot and its environment](https://github.com/Co-Evolve/brb/tree/new-framework/brb/brittle_star) are part of [**the Bio-inspired Robotics Benchmark (BRB)**](https://github.com/Co-Evolve/brb). Instead of directly outputting joint-level actions, our Q-learned controller will modulate a central pattern generator (CPG), which in turn outputs the joint-level actions.


## Q-Learning

* Belongs to the class of model-free algorithms, meaning that it does not require prior knowledge (or a model) of the environment.
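* Is an off-policy algorithm, meaning that the policy it learns about (the greedy policy with respect to the Q-values) can differ from the behavior policy that actually produces the actions (e.g. an exploring epsilon-greedy policy)
* As its name suggests, the goal of the algorithm is to learn the Q-function
    * The Q-function receives a state-action pair and (tries to) return the expected cumulative discounted reward of taking that action in that state and acting optimally afterwards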
        * In other words, it tries to predict the expected payoff of doing a certain action in a given state
* In this tutorial we will focus on tabular Q-learning
    * We will optimize the values of a two-dimensional table with states as rows and actions as columns.
    * Each cell in this table holds the Q-value of one state-action pair
        * As we use a table, the states and actions have to be discretized.
    * Initially, this Q-table will be populated with arbitrary values or zeros
    * The Q-learning algorithm then tries to optimize the values in this table through an iterative process of exploration and exploitation.
        * During the exploration phase, the agent will take random actions to gather information about the environment and update the Q-table accordingly
        * As the agent explores more, it gradually transitions to the exploitation phase, where it leverages the learned Q-values to make more informed decisions and maximize the cumulative reward
    * The Q-learning algorithm can be summarized as follows:
        1. Initialize the Q-table with arbitrary values or zeros
        2. Observe the current state of the environment
        3. Choose an action based on an exploration-exploitation trade-off. This can, for instance, be done with an exploration strategy like epsilon-greedy, where the agent selects a random action with probability $\epsilon$ and the action with the highest Q-value with the complementary probability $1 - \epsilon$.
        4. Perform the chosen action and observe the reward and the resulting next state
        5. Update the Q-value of the state-action pair using the Q-learning update rule:<br>
        $Q(s,a) \leftarrow (1 - \alpha)\, Q(s,a) + \alpha \left(r + \gamma \max_{a'} Q(s', a')\right)$<br>
        where $\alpha$ is the learning rate, $\gamma$ is the discount factor that determines the importance of future rewards, $r$ is the immediate reward returned by the environment, and $\max_{a'} Q(s', a')$ is the highest Q-value among all actions $a'$ in the next state $s'$.
        6. Repeat steps $2$ to $5$ until convergence or until a predefined number of iterations has been reached.
* A natural extension of tabular Q-learning is the Deep Q-Network (DQN) algorithm. As its name suggests, DQN swaps the Q-table for a deep neural network that maps a state to the Q-values $Q(s, a)$ of all actions, so that the state space no longer has to be discretized.
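To make the update rule concrete, here is a single update worked out in plain Python. The numbers ($\alpha = 0.1$, $\gamma = 0.9$, a reward of $1$, and the Q-values of the next state) are arbitrary illustration values, not values used later in this tutorial.

```python
from typing import Sequence


def q_update(q_sa: float, reward: float, next_q_values: Sequence[float], alpha: float, gamma: float) -> float:
    """Return the new Q(s, a) after one tabular Q-learning update."""
    # Bootstrapped target: immediate reward plus the discounted value of the best next action.
    target = reward + gamma * max(next_q_values)
    # Move the old estimate a fraction alpha towards the target.
    return (1 - alpha) * q_sa + alpha * target


new_q = q_update(q_sa=0.5, reward=1.0, next_q_values=[0.2, 0.8], alpha=0.1, gamma=0.9)
print(round(new_q, 4))  # → 0.622
```

Here the target is $1 + 0.9 \cdot 0.8 = 1.72$, and the new estimate moves 10% of the way from $0.5$ towards it: $0.9 \cdot 0.5 + 0.1 \cdot 1.72 = 0.622$.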

## Implementing tabular Q-learning in JAX

When implementing something in JAX, it is important to remember that JAX follows the functional programming paradigm. Put simply, our primary building blocks are pure functions (deterministic and without side effects) and immutable data structures (instead of changing data in place, we create new data structures with the desired modifications).
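As a small illustration of the immutability point: JAX arrays cannot be modified in place, so instead of `q_table[0, 1] = 1.0` one uses the `.at[...].set(...)` syntax, which returns a new array and leaves the original untouched. The array name and shape below are purely illustrative.

```python
import jax.numpy as jnp

q_table = jnp.zeros((2, 3))

# In-place item assignment is not supported on JAX arrays and raises a TypeError.
try:
    q_table[0, 1] = 1.0
except TypeError:
    print("in-place assignment is not allowed")

# Functional update: returns a *new* array with the requested change.
updated_q_table = q_table.at[0, 1].set(1.0)

print(q_table[0, 1])          # the original is unchanged
print(updated_q_table[0, 1])  # the copy holds the new value
```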

We will thus start by creating a data structure that holds the current state of our Q-learner together with its related learning (hyper)parameters.

In [8]:
import random

import gymnasium
from flax import struct
import jax
import jax.numpy as jnp


@struct.dataclass
class QLearnerState:
    q_table: jnp.ndarray
    alpha: float
    epsilon: float
    gamma: float


def initialize_q_learning_state(
        num_states: int,
        num_actions: int,
        alpha: float,
        epsilon: float,
        gamma: float,
        rng: jnp.ndarray
        ) -> QLearnerState:
    # noinspection PyArgumentList
    return QLearnerState(
            q_table=jax.random.uniform(
                    key=rng, shape=(num_states, num_actions), dtype=jnp.float32, minval=-0.001, maxval=0.001
                    ), alpha=alpha, epsilon=epsilon, gamma=gamma
            )

The `QLearnerState` will be the main data structure that we'll pass around between functions.
Next, we can implement a function that receives the current `QLearnerState`, the index of the current environment state (which represents the environment's state in a discretized manner), the index of the next state, an action index, and a reward, and returns an updated `QLearnerState` after applying the Q-learning update rule.

In [9]:
def apply_q_learning_update_rule(
        q_learner_state: QLearnerState,
        state_index: int,
        next_state_index: int,
        action_index: int,
        reward: float
        ) -> QLearnerState:
    old_q_value = q_learner_state.q_table[state_index, action_index]
    best_future_q_value = jnp.max(q_learner_state.q_table[next_state_index])
    q_value_update = reward + q_learner_state.gamma * best_future_q_value
    new_q_value = (1 - q_learner_state.alpha) * old_q_value + q_learner_state.alpha * q_value_update

    new_q_table = q_learner_state.q_table.at[state_index, action_index].set(new_q_value)
    return q_learner_state.replace(q_table=new_q_table)

Next, we can implement a function that represents our policy. This function receives the current `QLearnerState`, a PRNG key and a state index, uses the epsilon-greedy strategy to handle the trade-off between exploration and exploitation, and returns an action index.

In [10]:
def epsilon_greedy_policy(
        q_learner_state: QLearnerState,
        rng: jnp.ndarray,
        state_index: int
        ) -> int:
    explore_rng, random_action_rng = jax.random.split(rng, 2)
    explore = jax.random.uniform(explore_rng) < q_learner_state.epsilon

    def get_random_action() -> int:
        return jax.random.choice(key=random_action_rng, a=jnp.arange(q_learner_state.q_table.shape[1]))

    def get_greedy_action() -> int:
        return jnp.argmax(q_learner_state.q_table[state_index])

    action_index = jax.lax.cond(
            pred=explore, true_fun=get_random_action, false_fun=get_greedy_action
            )

    return action_index
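The policy above uses a fixed $\epsilon$, whereas the overview at the top describes a gradual transition from exploration to exploitation. One common way to obtain that behavior is to decay $\epsilon$ over the course of training. The sketch below uses a linear schedule; the function name and the start and end values are illustrative assumptions, not part of the code above. The resulting value could be written into the learner state with `q_learner_state.replace(epsilon=...)` at the start of each episode.

```python
def linear_epsilon_schedule(
        episode: int,
        num_episodes: int,
        start_epsilon: float = 1.0,
        end_epsilon: float = 0.05
        ) -> float:
    """Hypothetical helper: linearly anneal epsilon from start_epsilon to end_epsilon."""
    # Fraction of training completed, clamped to [0, 1].
    progress = min(episode / max(num_episodes - 1, 1), 1.0)
    return start_epsilon + progress * (end_epsilon - start_epsilon)


print([linear_epsilon_schedule(episode, num_episodes=5) for episode in range(5)])
```

A linear schedule is only one option; exponential decay is an equally common choice.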

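Putting these pieces together gives a general training loop. Here, `state_indexer` maps an observation to a discrete state index, and `action_mapper` translates an action index into joint-level actions while carrying along its own state (for instance the state of a CPG):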

In [11]:
from typing import Any, Callable, Dict, Tuple
from tqdm import tqdm
from mujoco_utils.environment.mjx_env import MJXGymEnvWrapper

ActionMapperState = Any

def train_q_learning_agent(
        q_learner_state: QLearnerState,
        num_episodes: int,
        env: MJXGymEnvWrapper,
        seed: int,
        state_indexer: Callable[[Dict[str, jnp.ndarray]], int],
        action_mapper: Callable[[ActionMapperState, int], Tuple[ActionMapperState, jnp.ndarray]],
        action_mapper_state: ActionMapperState
        ) -> QLearnerState:
    key = jax.random.PRNGKey(seed=seed)

    for _ in tqdm(range(num_episodes)):
        obs, info = env.reset()
        state_index = state_indexer(obs)

        done = False
        while not done:
            key, action_rng = jax.random.split(key, 2)

            action_index = epsilon_greedy_policy(
                    q_learner_state=q_learner_state, rng=action_rng, state_index=state_index
                    )
            action_mapper_state, actions = action_mapper(action_mapper_state, action_index)
            obs, reward, terminated, truncated, info = env.step(actions=actions)

            new_state_index = state_indexer(obs)
            q_learner_state = apply_q_learning_update_rule(
                    q_learner_state=q_learner_state,
                    state_index=state_index,
                    next_state_index=new_state_index,
                    action_index=action_index,
                    reward=reward
                    )

            done = jnp.any(terminated | truncated)
            state_index = new_state_index

    return q_learner_state
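Before plugging in the brittle star, it can help to sanity-check the same update rule and epsilon-greedy policy on a problem whose answer is known. The self-contained sketch below uses a hypothetical five-cell corridor (not part of BRB) in plain Python: the only reward is given for reaching the rightmost cell, so the learned greedy policy should move right in every other cell. It mirrors `apply_q_learning_update_rule` and `epsilon_greedy_policy` rather than importing them, and adds two details the JAX versions above do not have: random tie-breaking between equal Q-values and no bootstrapping from terminal states.

```python
import random
from typing import List, Tuple

NUM_STATES = 5   # corridor cells 0..4; cell 4 is the terminal goal
NUM_ACTIONS = 2  # 0 = move left, 1 = move right


def corridor_step(state: int, action: int) -> Tuple[int, float, bool]:
    """Toy deterministic environment: reward 1 for reaching the rightmost cell."""
    next_state = max(state - 1, 0) if action == 0 else min(state + 1, NUM_STATES - 1)
    done = next_state == NUM_STATES - 1
    return next_state, (1.0 if done else 0.0), done


def greedy_action(q_row: List[float], rng: random.Random) -> int:
    # Break ties randomly so that an all-zero table does not always pick action 0.
    best = max(q_row)
    return rng.choice([a for a, q in enumerate(q_row) if q == best])


def train_corridor(
        num_episodes: int = 300,
        alpha: float = 0.5,
        gamma: float = 0.9,
        epsilon: float = 0.5,
        max_steps: int = 100,
        seed: int = 0
        ) -> List[List[float]]:
    rng = random.Random(seed)
    q_table = [[0.0] * NUM_ACTIONS for _ in range(NUM_STATES)]
    for _ in range(num_episodes):
        state = 0
        for _ in range(max_steps):
            # Epsilon-greedy behavior policy.
            if rng.random() < epsilon:
                action = rng.randrange(NUM_ACTIONS)
            else:
                action = greedy_action(q_table[state], rng)
            next_state, reward, done = corridor_step(state, action)
            # Terminal transitions do not bootstrap from the next state.
            target = reward if done else reward + gamma * max(q_table[next_state])
            q_table[state][action] = (1 - alpha) * q_table[state][action] + alpha * target
            state = next_state
            if done:
                break
    return q_table


q_table = train_corridor()
# Greedy action in each non-terminal cell (1 = move right).
print([max(range(NUM_ACTIONS), key=lambda a: row[a]) for row in q_table[:-1]])  # → [1, 1, 1, 1]
```

In this deterministic corridor the optimal values are $Q(s, \text{right}) = \gamma^{3 - s}$, so for example $Q(0, \text{right})$ should approach $0.9^3 = 0.729$.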

We can also create a general evaluation loop that we can use later to visualize our learned policy:

In [13]:
import numpy as np
from typing import List


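def visualize_q_learning_agent(
        q_learner_state: QLearnerState,
        env: MJXGymEnvWrapper,
        seed: int,
        state_indexer: Callable[[Dict[str, jnp.ndarray]], int],
        action_mapper: Callable[[ActionMapperState, int], Tuple[ActionMapperState, jnp.ndarray]],
        action_mapper_state: ActionMapperState
        ) -> List[np.ndarray]:
    key = jax.random.PRNGKey(seed=seed)
    # Act greedily: the learned policy is the greedy policy with respect to the Q-table
    q_learner_state = q_learner_state.replace(epsilon=0.0)

    obs, info = env.reset()
    state_index = state_indexer(obs)

    frames = []
    done = False
    while not done:
        key, action_rng = jax.random.split(key, 2)

        action_index = epsilon_greedy_policy(
                q_learner_state=q_learner_state, rng=action_rng, state_index=state_index
                )
        # Same (stateful) action mapper interface as in the training loop
        action_mapper_state, actions = action_mapper(action_mapper_state, action_index)
        obs, reward, terminated, truncated, info = env.step(actions=actions)

        frames.append(env.render())
        done = jnp.any(terminated | truncated)
        # Move on to the state we just observed
        state_index = state_indexer(obs)

    return frames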

Great, we have implemented the tabular Q-learning algorithm. Time to test it out with the brittle star environment!

## Case study: CPG modulations for directed brittle star locomotion

### Environment setup
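* Load BRB's brittle star environment for targeted locomotion (moving towards a target)
* Create a state indexer that maps an observation to a discrete state index
* Create an action mapper that maps an action index to joint-level actions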

In [33]:
from mujoco_utils.environment.base import MuJoCoEnvironmentConfiguration
from typing import Union
import mediapy as media
from brb.brittle_star.environment.target.mjx_env import BrittleStarTargetMJXEnvironment
from brb.brittle_star.mjcf.morphology.morphology import MJCFBrittleStarMorphology
from brb.brittle_star.mjcf.morphology.specification.default import default_brittle_star_morphology_specification
from brb.brittle_star.mjcf.arena.aquarium import AquariumArenaConfiguration, MJCFAquariumArena
from brb.brittle_star.environment.target.mjc_env import BrittleStarTargetEnvironmentConfiguration

morphology_specification = default_brittle_star_morphology_specification(
        num_arms=5, num_segments_per_arm=[2, 0, 2, 0, 0], use_p_control=True, use_torque_control=False
        )
arena_configuration = AquariumArenaConfiguration(
        size=(10, 5), sand_ground_color=False, attach_target=True, wall_height=1.5, wall_thickness=0.1
        )
environment_configuration = BrittleStarTargetEnvironmentConfiguration(
        # Distance to put our target at (targets are spawned on a circle around the starting location with this given radius).
        target_distance=3.0,
        joint_randomization_noise_scale=0.0,
        render_mode="rgb_array",
        simulation_time=10,
        num_physics_steps_per_control_step=10,
        camera_ids=[0, 1]
        )

def create_environment(
        num_environments: int = 1
        ) -> MJXGymEnvWrapper:
    morphology = MJCFBrittleStarMorphology(
            specification=morphology_specification
            )
    arena = MJCFAquariumArena(
            configuration=arena_configuration
            )
    env = BrittleStarTargetMJXEnvironment(morphology=morphology, arena=arena, configuration=environment_configuration)
    return MJXGymEnvWrapper(env=env, num_envs=num_environments)


def post_environment_render(
        render_output: Union[np.ndarray, List[np.ndarray]],
        ) -> np.ndarray:
    if len(environment_configuration.camera_ids) > 1:
        render_output = np.concatenate(render_output, axis=1)
    
    return render_output[:, :, ::-1]  # RGB to BGR

In [34]:
env = create_environment(num_environments=1)
print("Observation space:")
print(env.observation_space)
print()
print("Action space:")
print(env.action_space)
env.reset()
media.show_image(post_environment_render(env.render()))
env.close()

Observation space:
Dict('disk_angular_velocity': Box(-inf, inf, (3,), float32), 'disk_linear_velocity': Box(-inf, inf, (3,), float32), 'disk_position': Box(-inf, inf, (3,), float32), 'disk_rotation': Box(-3.1415927, 3.1415927, (3,), float32), 'in_plane_joint_position': Box(-0.5235988, 0.5235988, (4,), float32), 'in_plane_joint_velocity': Box(-inf, inf, (4,), float32), 'out_of_plane_joint_position': Box(-0.5235988, 0.5235988, (4,), float32), 'out_of_plane_joint_velocity': Box(-inf, inf, (4,), float32), 'segment_contact': Box(0.0, 1.0, (4,), float32), 'unit_xy_direction_to_target': Box(-1.0, 1.0, (2,), float32), 'xy_distance_to_target': Box(0.0, inf, (1,), float32))

Action space:
Box(-0.5235988, 0.5235988, (8,), float32)
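A state indexer maps an observation to a single integer row of the Q-table. One simple option, sketched below, is to discretize the direction to the target (`unit_xy_direction_to_target` in the observation space above) into a fixed number of equally sized angular sectors. Whether this direction is expressed in the world frame or in the robot's own frame is an assumption here: if it is world-frame, the robot's heading (for instance the yaw component of `disk_rotation`) should be subtracted, which is what the optional `heading` argument is for. The function name and the number of bins are illustrative choices.

```python
import jax.numpy as jnp


def direction_state_indexer(obs: dict, num_bins: int = 8, heading: float = 0.0) -> jnp.ndarray:
    """Hypothetical state indexer: bin the xy direction to the target into angular sectors."""
    # Assumes a single environment; drop a possible leading batch dimension.
    direction = jnp.reshape(jnp.asarray(obs["unit_xy_direction_to_target"]), (-1, 2))[0]
    # Angle of the target direction, optionally relative to the robot's heading.
    angle = jnp.arctan2(direction[1], direction[0]) - heading
    # Shift into [0, 2*pi) so that it can be split into equally sized sectors.
    wrapped = jnp.mod(angle + jnp.pi, 2 * jnp.pi)
    index = jnp.floor(wrapped / (2 * jnp.pi) * num_bins).astype(jnp.int32)
    # Guard against floating-point round-off at the upper edge.
    return jnp.clip(index, 0, num_bins - 1)
```

With this indexer the Q-table would have `num_bins` rows, e.g. `num_states=8` in `initialize_q_learning_state`.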


### CPG setup

* We use the CPG model from the Sproewitz paper
* Two bidirectionally coupled oscillators per arm
    * One for the in-plane joints and one for the out-of-plane joints
        * These two are kept in anti-phase (the in-plane forward stroke should align with the out-of-plane upward stroke)
    * The Q-learned controller modulates the oscillators' amplitudes
    * Every segment of an arm receives the same action
* The in-plane oscillator of arm 1 is coupled to the in-plane oscillator of arm 2 (and likewise for the out-of-plane oscillators)
    * These are kept in phase
* Show image of oscillator to body mapping

In [44]:
# Constants
NUM_OSCILLATORS = 4 # [arm1_in_plane, arm1_out_of_plane, arm2_in_plane, arm2_out_of_plane]
AMPLITUDE_GAIN = 20
OFFSET_GAIN = 20
ADJACENCY_MATRIX = jnp.array([
        [0, 1, 1, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 1],
        [0, 1, 1, 0]
        ]) 
WEIGHTS = ADJACENCY_MATRIX * 5.0
OFFSETS = jnp.zeros(NUM_OSCILLATORS)
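# In-plane and out-of-plane oscillators of the same arm are kept in anti-phase.
# Connections between the arms (0-2 and 1-3) keep a zero phase bias, i.e. the arms stay in phase.
anti_phase_oscillator_pairs = [(0, 1), (2, 3)]
RHOS = jnp.zeros_like(ADJACENCY_MATRIX, dtype=jnp.float32)
for i, j in anti_phase_oscillator_pairs:
    # Set both directions of the connection: rho_ji = -rho_ij, and -pi equals pi modulo 2*pi.
    RHOS = RHOS.at[i, j].set(jnp.pi)
    RHOS = RHOS.at[j, i].set(jnp.pi)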
OMEGAS = jnp.ones(NUM_OSCILLATORS) * 2 * jnp.pi
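A quick way to check that a phase-bias matrix is consistent is to plug in the phase pattern we want and verify that, for every connected pair, the phase difference $\theta_j - \theta_i$ equals $\rho_{ij}$ modulo $2\pi$, which makes the pattern a stable phase lock of the coupling term $\sin(\theta_j - \theta_i - \rho_{ij})$. The self-contained NumPy sketch below does this for the pattern described in the bullets above (anti-phase within an arm, the two arms in phase), assuming phase biases of $\pi$ set in both directions of each anti-phase connection, and contrasts it with biases set in one direction only.

```python
import numpy as np

# Same coupling structure as ADJACENCY_MATRIX: [arm1_in, arm1_out, arm2_in, arm2_out].
adjacency = np.array([
        [0, 1, 1, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 1],
        [0, 1, 1, 0]
        ])
# Desired pattern: each arm's in-plane/out-of-plane pair in anti-phase, both arms in phase.
desired_phases = np.array([0.0, np.pi, 0.0, np.pi])


def max_phase_mismatch(rhos: np.ndarray, phases: np.ndarray) -> float:
    """Largest 1 - cos(theta_j - theta_i - rho_ij) over all connected oscillator pairs."""
    # A stable phase lock needs theta_j - theta_i == rho_ij (mod 2*pi), i.e. cos(...) == 1.
    # sin(...) == 0 alone is not enough: it also holds at the unstable, off-by-pi solution.
    arguments = phases[None, :] - phases[:, None] - rhos
    return float(np.max(adjacency * (1.0 - np.cos(arguments))))


# Phase biases set in both directions of each anti-phase connection.
symmetric_rhos = np.zeros((4, 4))
for i, j in [(0, 1), (2, 3)]:
    symmetric_rhos[i, j] = symmetric_rhos[j, i] = np.pi

# Phase biases set in one direction only.
one_way_rhos = np.zeros((4, 4))
for i, j in [(0, 1), (2, 3)]:
    one_way_rhos[i, j] = np.pi

print(max_phase_mismatch(symmetric_rhos, desired_phases))  # ~0: consistent
print(max_phase_mismatch(one_way_rhos, desired_phases))    # 2.0: rho_10 = 0 asks for in-phase instead
```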

Copied CPG implementation from the CPG tutorial:

In [45]:
@struct.dataclass
class CPGState:
    time: float
    phases: jnp.ndarray
    dot_amplitudes: jnp.ndarray  # first order derivative of the amplitude
    amplitudes: jnp.ndarray
    dot_offsets: jnp.ndarray  # first order derivative of the offset 
    offsets: jnp.ndarray
    outputs: jnp.ndarray

# Initial conditions
def get_random_initial_state(
        rng: jnp.ndarray
        ) -> CPGState:
    phase_rng, amplitude_rng, offsets_rng = jax.random.split(rng, 3)
    # noinspection PyArgumentList
    state = CPGState(
            phases=jax.random.uniform(
                    key=phase_rng, shape=(NUM_OSCILLATORS,), dtype=jnp.float32, minval=-jnp.pi, maxval=jnp.pi
                    ),
            # Each quantity uses its own key; reusing phase_rng would correlate the initial values
            amplitudes=jax.random.uniform(
                    key=amplitude_rng, shape=(NUM_OSCILLATORS,), dtype=jnp.float32, minval=0, maxval=1
                    ),
            offsets=jax.random.uniform(
                    key=offsets_rng, shape=(NUM_OSCILLATORS,), dtype=jnp.float32, minval=-1, maxval=1
                    ),
            dot_amplitudes=jnp.zeros(NUM_OSCILLATORS),
            dot_offsets=jnp.zeros(NUM_OSCILLATORS),
            outputs=jnp.zeros(NUM_OSCILLATORS),
            time=0.0
            )
    return state

def rk4_solver(
        current_time: float,
        y: float,
        derivative_fn: Callable[[float, float], float],
        delta_time: float
        ) -> float:
    # Slope at the start of the interval (the only slope the explicit Euler method uses)
    slope1 = derivative_fn(current_time, y)
    # These are additional slope calculations that improve our approximation of the true slope  
    slope2 = derivative_fn(current_time + delta_time / 2, y + slope1 * delta_time / 2)
    slope3 = derivative_fn(current_time + delta_time / 2, y + slope2 * delta_time / 2)
    slope4 = derivative_fn(current_time + delta_time, y + slope3 * delta_time)
    average_slope = (slope1 + 2 * slope2 + 2 * slope3 + slope4) / 6
    next_y = y + average_slope * delta_time
    return next_y

def phase_de(
        omegas: jnp.ndarray,
        amplitudes: jnp.ndarray,
        phases: jnp.ndarray,
        phase_biases: jnp.ndarray
        ) -> jnp.ndarray:
    @jax.vmap  # vectorizes this function for us over an additional batch dimension (in this case over all oscillators)
    def sine_term(
            phase_i: float,
            phase_biases_i: float
            ) -> jnp.ndarray:
        return jnp.sin(phases - phase_i - phase_biases_i)

    couplings = jnp.sum(WEIGHTS * amplitudes * sine_term(phase_i=phases, phase_biases_i=phase_biases), axis=1)
    return omegas + couplings

def second_order_de(
        gain: jnp.ndarray,
        modulator: jnp.ndarray,
        values: jnp.ndarray,
        dot_values: jnp.ndarray
        ) -> jnp.ndarray:
    return gain * ((gain / 4) * (modulator - values) - dot_values)


def first_order_de(
        dot_values: jnp.ndarray
        ) -> jnp.ndarray:
    return dot_values


def output(
        offsets: jnp.ndarray,
        amplitudes: jnp.ndarray,
        phases: jnp.ndarray
        ) -> jnp.ndarray:
    return offsets + amplitudes * jnp.cos(phases)

# Remember: derivative_fn expects (time, y)
def step_cpg(
        delta_time: float,
        state: CPGState,
        solver: Callable,
        # Modulation parameters
        omegas: jnp.ndarray,
        R: jnp.ndarray,
        X: jnp.ndarray,
        rhos: jnp.ndarray
        ) -> CPGState:
    # Update phase
    new_phases = solver(
            current_time=state.time,
            y=state.phases,
            derivative_fn=lambda t, y: phase_de(
                    omegas=omegas, amplitudes=state.amplitudes, phases=y, phase_biases=rhos
                    ),
            delta_time=delta_time
            )
    # Update amplitudes (second-order system: integrate the derivative, then the amplitude itself)
    new_dot_amplitudes = solver(
            current_time=state.time,
            y=state.dot_amplitudes,
            derivative_fn=lambda t, y: second_order_de(
                    gain=AMPLITUDE_GAIN, modulator=R, values=state.amplitudes, dot_values=y
                    ),
            delta_time=delta_time
            )
    new_amplitudes = solver(
            current_time=state.time,
            y=state.amplitudes,
            derivative_fn=lambda t, y: first_order_de(dot_values=state.dot_amplitudes),
            delta_time=delta_time
            )
    # Update offsets (same second-order scheme as the amplitudes)
    new_dot_offsets = solver(
            current_time=state.time,
            y=state.dot_offsets,
            derivative_fn=lambda t, y: second_order_de(
                    gain=OFFSET_GAIN, modulator=X, values=state.offsets, dot_values=y
                    ),
            delta_time=delta_time
            )
    new_offsets = solver(
            current_time=state.time,
            y=state.offsets,
            derivative_fn=lambda t, y: first_order_de(dot_values=state.dot_offsets),
            delta_time=delta_time
            )

    new_outputs = output(offsets=new_offsets, amplitudes=new_amplitudes, phases=new_phases)
    # noinspection PyUnresolvedReferences
    return state.replace(
            phases=new_phases,
            dot_amplitudes=new_dot_amplitudes,
            amplitudes=new_amplitudes,
            dot_offsets=new_dot_offsets,
            offsets=new_offsets,
            outputs=new_outputs,
            time=state.time + delta_time
            )
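To see why `rk4_solver` spends three extra slope evaluations per step, one can compare it with a plain Euler step on an equation with a known solution: $\dot{y} = y$ with $y(0) = 1$, whose exact value at $t = 1$ is $e$. The sketch below is self-contained plain Python; the step size of $0.1$ is chosen purely for illustration.

```python
import math
from typing import Callable

DerivativeFn = Callable[[float, float], float]


def euler_step(t: float, y: float, derivative_fn: DerivativeFn, dt: float) -> float:
    # Follow the slope at the start of the interval.
    return y + derivative_fn(t, y) * dt


def rk4_step(t: float, y: float, derivative_fn: DerivativeFn, dt: float) -> float:
    # Same scheme as rk4_solver: weighted average of four slope estimates.
    k1 = derivative_fn(t, y)
    k2 = derivative_fn(t + dt / 2, y + k1 * dt / 2)
    k3 = derivative_fn(t + dt / 2, y + k2 * dt / 2)
    k4 = derivative_fn(t + dt, y + k3 * dt)
    return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6 * dt


def integrate(step_fn: Callable[[float, float, DerivativeFn, float], float], num_steps: int = 10, dt: float = 0.1) -> float:
    # Solve dy/dt = y with y(0) = 1 up to t = num_steps * dt; the exact answer is exp(t).
    t, y = 0.0, 1.0
    for _ in range(num_steps):
        y = step_fn(t, y, lambda _t, _y: _y, dt)
        t += dt
    return y


print(abs(integrate(euler_step) - math.e))  # roughly 0.12
print(abs(integrate(rk4_step) - math.e))    # roughly 2e-6
```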


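Connecting the CPG to the environment:
* Make sure the actions stay within the range of the given action space (e.g. by clipping them)
* Match the timelines: the CPG should be integrated with the same time step as the environment's control step (the physics time step multiplied by `num_physics_steps_per_control_step`), so that the CPG and the environment advance in sync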
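The two requirements above can be sketched as follows. The discrete action set (which arm oscillates with which amplitude), the order of the eight actuators, and the physics time step are all assumptions made for illustration; the actuator order in particular should be checked against the environment before use. The joint limit of $0.5235988$ rad is taken from the action space printed earlier, and the function names are hypothetical.

```python
import jax.numpy as jnp

JOINT_LIMIT = 0.5235988   # from the printed action space Box(-0.5235988, 0.5235988, (8,))
NUM_SEGMENTS_PER_ARM = 2  # matches num_segments_per_arm=[2, 0, 2, 0, 0]

# Hypothetical discrete action set: target amplitude (rad) for [arm 1, arm 2].
ACTION_AMPLITUDES = jnp.array([
        [0.0, 0.0],
        [0.4, 0.0],
        [0.0, 0.4],
        [0.4, 0.4]
        ])


def action_index_to_amplitude_modulation(action_index: int) -> jnp.ndarray:
    """Map an action index to the R modulation of [arm1_in, arm1_out, arm2_in, arm2_out]."""
    # Both oscillators of an arm share that arm's amplitude.
    return jnp.repeat(ACTION_AMPLITUDES[action_index], 2)


def cpg_outputs_to_joint_actions(cpg_outputs: jnp.ndarray) -> jnp.ndarray:
    """Map the 4 oscillator outputs to 8 joint actions and keep them inside the action space."""
    per_arm = cpg_outputs.reshape(2, 2)  # rows: arms, columns: [in-plane, out-of-plane]
    # Every segment of an arm receives the same command.
    per_segment = jnp.repeat(per_arm, NUM_SEGMENTS_PER_ARM, axis=0)
    # ASSUMED actuator order: per segment, the in-plane joint followed by the out-of-plane joint.
    return jnp.clip(per_segment.reshape(-1), -JOINT_LIMIT, JOINT_LIMIT)


def control_timestep(physics_timestep: float, num_physics_steps_per_control_step: int) -> float:
    # The CPG should advance by exactly one control step every time env.step is called.
    return physics_timestep * num_physics_steps_per_control_step


# ASSUMPTION: 0.002 s is MuJoCo's default physics time step; BRB may configure a different value.
dt = control_timestep(0.002, 10)
print(dt, round(10 / dt))  # control time step and number of control steps in a 10 s episode
```

Inside an `action_mapper` with the `(state, action_index) -> (state, actions)` signature used by `train_q_learning_agent`, one would call `step_cpg` with `delta_time` set to this control time step, `solver=rk4_solver`, `R=action_index_to_amplitude_modulation(action_index)`, `X=OFFSETS`, `omegas=OMEGAS` and `rhos=RHOS`, and then return the new `CPGState` together with `cpg_outputs_to_joint_actions(new_state.outputs)`.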

In [ ]:
env = create_environment(1)

### Exploit JAX: vectorize!
Todo: compare training times
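Because these functions are pure, `jax.vmap` can turn a single-environment function into one that handles a whole batch of environments at once, and `jax.jit` can compile the result. The sketch below illustrates this on a simplified epsilon-greedy function, a hypothetical stand-in for `epsilon_greedy_policy` that takes the Q-table and $\epsilon$ directly. It uses `jnp.where` instead of `jax.lax.cond`; under `vmap`, a `cond` with a batched predicate is turned into a select that evaluates both branches anyway.

```python
import jax
import jax.numpy as jnp


def epsilon_greedy_action(
        q_table: jnp.ndarray,
        rng: jnp.ndarray,
        state_index: jnp.ndarray,
        epsilon: float
        ) -> jnp.ndarray:
    """Hypothetical single-environment epsilon-greedy action selection on a raw Q-table."""
    explore_rng, action_rng = jax.random.split(rng)
    random_action = jax.random.randint(action_rng, shape=(), minval=0, maxval=q_table.shape[1])
    greedy_action = jnp.argmax(q_table[state_index])
    # Both candidates are cheap to compute, so a select vectorizes cleanly.
    return jnp.where(jax.random.uniform(explore_rng) < epsilon, random_action, greedy_action)


# Vectorize over a batch of (rng, state_index) pairs while sharing the Q-table and epsilon.
batched_policy = jax.jit(jax.vmap(epsilon_greedy_action, in_axes=(None, 0, 0, None)))

q_table = jnp.array([[0.0, 1.0], [2.0, 0.5], [0.1, 0.3]])
num_envs = 3
rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)
state_indices = jnp.array([0, 1, 2])
print(batched_policy(q_table, rngs, state_indices, 0.0))  # epsilon = 0, so purely greedy: [1 0 1]
```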