
#### Connect, learn and contribute to help yourself and others land a job in the AI space

Looking for a way to contribute or learn more about AI/ML? Connect with me:
- LinkedIn: [https://www.linkedin.com/in/stefruinard/](https://www.linkedin.com/in/stefruinard/)
- Medium: [https://medium.com/@stefruinard](https://medium.com/@stefruinard)
- GitHub: [https://github.com/Sruinard](https://github.com/Sruinard)


# Machine learning with JAX part IV
Welcome to the fourth and final blog in this series on learning the fundamentals of JAX. Although all parts can be read separately, you might want to check out the earlier blogs in this series, which you can find [here](), [here]() and [here](). With that out of the way, welcome to SpaceY, we're thrilled to have you! After some intense onboarding, SpaceY has set you up for your next challenge: landing a rocket safely on the surface of the moon using reinforcement learning, JAX and [Haiku](https://dm-haiku.readthedocs.io/en/latest/). Are you ready for this next adventure? Great! Let's land some rockets.


## What is reinforcement learning?
Our main goal is to introduce you to the wonderful world of JAX (and Haiku in this case), but where is the fun in sticking to just that? So instead of training a model on the MNIST dataset, let's do something slightly more complex: training a reinforcement learning (RL) agent to land a rocket on the moon. Reinforcement learning is an exciting, relatively 'new' area of machine learning. Although it is still mostly used in research, more and more applications of reinforcement learning are being deployed in practice, for example at [DeepMind](https://www.deepmind.com/blog/deep-reinforcement-learning) and [Microsoft](https://blogs.microsoft.com/ai/reinforcement-learning/).

In case you are not familiar with RL, here is the gist of it.
In reinforcement learning we have an agent (sometimes referred to as a policy) that can take a set of actions in an environment. For example, a game can be the environment and a player who walks around in that game can be the agent. Every time the agent takes an action, it 'observes' the environment and finds itself in a new state (you could also say it observes the state it is in). Given the action taken, it also receives a reward or a punishment (this is where the 'reinforcement' in reinforcement learning comes from), which signals to the agent how happy it should be with that action. This closes the feedback loop. Alright, that is quite a few moving components: agents, actions, the environment, rewards and observations/states. To make it slightly more intuitive, the following visual overview will probably help:

![alt](https://github.com/Sruinard/machine_learning_novice/blob/main/assets/ml_with_jax_part_4/reinforcement_learning_concepts.png?raw=true)

Let's map the terms we just introduced (e.g. agent, action, reward, state) onto terms you might be more familiar with:

![](https://github.com/Sruinard/machine_learning_novice/blob/main/assets/ml_with_jax_part_4/reinforcement_learning_concepts_with_familiar_terms.png?raw=true)
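To make the feedback loop concrete, here is a minimal pure-Python sketch of one episode. `ToyEnv` is a hypothetical stand-in invented for illustration (its state is simply the remaining distance to the ground, and its reward rule is made up); it is not part of gym, although its `reset`/`step` methods are loosely modelled on gym's interface. The agent just picks random actions, as an untrained policy effectively would.

```python
import random


class ToyEnv:
    """Hypothetical environment: the state is the remaining distance to the ground."""

    def reset(self) -> int:
        self.distance = 5
        return self.distance

    def step(self, action: int):
        # Action 1 moves one unit closer to the ground, action 0 does nothing.
        if action == 1:
            self.distance -= 1
        reward = 1.0 if action == 1 else -0.1  # made-up rule: reward progress, punish idling
        done = self.distance == 0
        return self.distance, reward, done


def run_episode(env, max_steps: int = 50, seed: int = 0):
    rng = random.Random(seed)
    state = env.reset()
    trajectory = []  # (state, action, reward, next_state, done) tuples
    for _ in range(max_steps):
        action = rng.choice([0, 1])  # an untrained agent simply acts at random
        next_state, reward, done = env.step(action)
        trajectory.append((state, action, reward, next_state, done))
        state = next_state  # the observation becomes the agent's new state
        if done:
            break
    return trajectory


trajectory = run_episode(ToyEnv())
print(f"{len(trajectory)} steps, total reward {sum(t[2] for t in trajectory):.1f}")
```

Each tuple in `trajectory` has the same shape as the experiences (state, action, reward, next state, done) that the agent stores in its memory later in this post.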

So what's next? What is it that we are actually trying to achieve? We want the agent to collect as much reward as possible, and we use the rewards it collects to train our neural network. For that, we need some data!
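Strictly speaking, what the agent maximizes is the *discounted* sum of future rewards, the return $G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \dots$, where the discount factor $\gamma$ (the `GAMMA` hyperparameter later in this post) weighs near-term rewards more heavily than distant ones. A small sketch with made-up reward values:

```python
def discounted_return(rewards, gamma):
    """Compute r_0 + gamma * r_1 + gamma**2 * r_2 + ... for rewards listed in time order."""
    g = 0.0
    # Walk backwards so that each step is G_t = r_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        g = r + gamma * g
    return g


rewards = [1.0, 1.0, 1.0, 100.0]  # hypothetical: a big bonus for landing at the end
print(discounted_return(rewards, gamma=0.995))  # far-sighted: the landing bonus counts almost fully
print(discounted_return(rewards, gamma=0.5))    # → 14.25, the landing bonus is heavily discounted
```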

## The dataset in reinforcement learning: not your typical dataset

In the previous parts we worked with a static dataset. To train our RL agent, we'll instead generate a dataset through simulation. Each time our agent interacts with the environment, we record the current state, the action, the reward, the next state and whether the episode (the game) has ended. We save this experience to the agent's memory, which holds at most 10,000 experiences (you'll see this later in code). As the agent learns more and more about the environment, its experiences will include higher rewards, from which the agent can learn again. It's basically teaching itself. Be aware that since the agent's memory is limited, the oldest experiences are pushed out and replaced by new ones.

![Building our dataset](../assets/ml_with_jax_part_4/dataset_rl.png)
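The 'oldest experiences are pushed out' behaviour comes for free with Python's `collections.deque` when `maxlen` is set, which is also what the notebook uses for its memory. A tiny sketch with a capacity of 3 instead of 10,000, storing plain integers as stand-ins for experiences (the `toy_` names are just for illustration):

```python
import random
from collections import deque

toy_memory = deque(maxlen=3)  # capacity shrunk from 10,000 to 3 for illustration
for experience_id in range(5):
    toy_memory.append(experience_id)  # once full, each append evicts the oldest item

print(list(toy_memory))  # → [2, 3, 4]: experiences 0 and 1 were pushed out

# Training batches are drawn uniformly at random from whatever is currently stored.
toy_batch = random.sample(list(toy_memory), k=2)
print(toy_batch)
```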


```python
import jax
import jax.numpy as jnp
import optax
import haiku as hk
import gym
```



## Creating the environment and inspecting the action space
We'll leverage the gym package created by [OpenAI](https://openai.com/) to create the LunarLander environment with which our RL agent will interact. Let's create the environment and check how many actions the agent can choose from.

```python
env = gym.make('LunarLander-v2', new_step_api=True)
print(env.action_space.n)
```

```
4
```


Easy as that. There are four discrete actions available, indexed 0 to 3: 0) do nothing, 1) fire the left orientation engine, 2) fire the main engine and 3) fire the right orientation engine.

```python
NUM_ACTIONS = env.action_space.n
```


## What we (still) need

- [x] Environment
- [ ] Memory buffer
- [ ] DQN model
- [ ] Loss function
- [ ] Training loop

With the environment created, let's move on to creating the memory buffer. We'll first create a training configuration that is used throughout this notebook. For now, pay particular attention to the memory size: this is the maximum number of experiences the buffer holds.

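```python
class TrainConfig:
    MEMORY_SIZE = 10000              # max number of experiences kept in the replay memory
    BATCH_SIZE = 32                  # experiences per gradient update
    UPDATE_PARAMS_EVERY_N_STEPS = 4  # run a gradient update every N environment steps
    GAMMA = 0.995                    # discount factor for future rewards
    TAU = 0.001                      # soft-update rate of the target network
    E_MIN = 0.01                     # lower bound for the exploration rate epsilon
    E_DECAY = 0.995                  # multiplicative epsilon decay, applied once per episode
    N_EPISODES = 900                 # number of training episodes
    MAX_N_STEPS_PER_EPISODE = 1000   # cap on the number of steps within one episode
```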

Here we create the actual memory: a simple queue in which we store `Experience`s. Each `Experience` is a single sample, and when we draw multiple samples from the queue, they form a batch used for training. The `Experience` dataclass contains all the information needed to train the agent.

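```python
import dataclasses
from collections import deque


@dataclasses.dataclass
class Experience:
    state: jnp.ndarray       # observation before the action
    action: int              # index of the action taken
    reward: float            # reward received for that action
    next_state: jnp.ndarray  # observation after the action
    done: bool               # whether the episode ended after this step


# With maxlen set, appending to a full deque silently drops the oldest experience.
memory = deque(maxlen=TrainConfig.MEMORY_SIZE)
```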

## What we (still) need

- [x] Environment
- [x] Memory buffer
- [ ] DQN model
- [ ] Loss function
- [ ] Training loop

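Next, it is time to create our model and learn more about JAX and Haiku. As always, we create a training state, since we are working in a functional programming paradigm. The major difference this time is that we will not create one, but two(!) networks: the primary network, whose parameters are trained, and a target network, which is used to compute the learning targets. This is done to make learning more stable. Every `TrainConfig.UPDATE_PARAMS_EVERY_N_STEPS` steps, the primary network's params are updated using gradient descent and the target network's params are moved slightly towards them. The training state also keeps `eval_params`, a slowly moving average of the primary params intended for evaluation. Next, we define the model in `network_fn()`.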

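```python
from typing import NamedTuple


class TrainingState(NamedTuple):
    params: hk.Params         # primary (online) network, trained with gradient descent
    target_params: hk.Params  # target network, used to compute the targets
    eval_params: hk.Params    # moving average of `params`, intended for evaluation
    opt_state: optax.OptState


class Batch(NamedTuple):
    states: jnp.ndarray       # shape (batch, 8)
    actions: jnp.ndarray      # shape (batch,), integer action indices
    rewards: jnp.ndarray      # shape (batch,)
    next_states: jnp.ndarray  # shape (batch, 8)
    dones: jnp.ndarray        # shape (batch,), booleans


def network_fn(x: jnp.ndarray) -> jnp.ndarray:
    """Maps a state (or a batch of states) to one Q-value per action."""
    model = hk.Sequential(
        [
            hk.Linear(64), jax.nn.relu,
            hk.Linear(64), jax.nn.relu,
            hk.Linear(NUM_ACTIONS),
        ]
    )
    return model(x)


# Both networks share the same architecture; they only differ in their params.
network = hk.without_apply_rng(hk.transform(network_fn))
target_network = hk.without_apply_rng(hk.transform(network_fn))
optimiser = optax.adam(1e-3)
# The networks and optimiser are initialised below, once we have example inputs to infer shapes from.
```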

To initialize the networks, Haiku needs some example input to infer the parameter shapes from. We therefore create a helper function that samples a batch of experiences from memory. Next, we play around in the environment by taking random actions and store those experiences in a small temporary memory.

```python
import random
import numpy as np


def get_random_batch(memory):
    """Sample BATCH_SIZE experiences uniformly at random and stack them into arrays."""
    batch = random.sample(memory, k=TrainConfig.BATCH_SIZE)
    return Batch(
        states=jnp.array([e.state for e in batch]),
        actions=jnp.array([e.action for e in batch]),
        rewards=jnp.array([e.reward for e in batch]),
        next_states=jnp.array([e.next_state for e in batch]),
        dones=jnp.array([e.done for e in batch]),
    )


small_memory = deque(maxlen=1000)

state = env.reset()
for _ in range(200):
    action = env.action_space.sample()  # take a random action
    next_state, reward, is_done, *_ = env.step(action)
    small_memory.append(Experience(state, action, reward, next_state, is_done))
    # Move to the next state, starting a new episode if this one ended.
    state = env.reset() if is_done else next_state

batch = get_random_batch(small_memory)

batch.states.shape
```

```
(32, 8)
```

Now that we have a batch of experiences, we can initialize the networks. Note again that our training state contains the parameters of two networks, plus the evaluation params, which start out as a copy of the primary params.

```python
initial_params = network.init(jax.random.PRNGKey(seed=0), batch.states)
initial_target_params = target_network.init(jax.random.PRNGKey(seed=1), batch.states)
initial_opt_state = optimiser.init(initial_params)
train_state = TrainingState(initial_params, initial_target_params, initial_params, initial_opt_state)
```



## What we (still) need

- [x] Environment
- [x] Memory buffer
- [x] DQN model
- [ ] Loss function
- [ ] Training loop

Aaaah, the good old loss function. In the previous blogs we were dealing with supervised learning, which meant we had labels. That's not the case anymore: we have to define our targets ourselves.

The target values are given by the Bellman equation:

$$
y = R + \gamma \max_{a'}Q(s',a';w)
$$

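where $\gamma$ is the discount factor, which determines whether the agent focuses on long-term rewards (when close to 1) or short-term rewards (when close to 0), $s'$ is the next state, $a'$ ranges over the actions available in $s'$, and $w$ are the weights of the neural network. When the episode ends after the transition, there is no next state to look ahead from and the target is simply $y = R$. By using a separate target network $\hat{Q}$ to compute $y$, the error (whose square we minimize) becomes: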

$$
\overbrace{\underbrace{R + \gamma \max_{a'}\hat{Q}(s',a'; w^{target})}_{\rm {y~target}} - Q(s,a;w)}^{\rm {Error}}
$$

where $w^{target}$ and $w$ are the weights of the target network $\hat{Q}$ and of the primary network $Q$, respectively. The loss is the mean of this error squared over a batch of experiences.
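To see the target and the error in numbers, here is a NumPy sketch for a batch of two transitions. All Q-values and rewards are made up; as in the notebook's loss, the bootstrap term is set to zero for a transition that ends the episode.

```python
import numpy as np

gamma = 0.995
rewards = np.array([1.0, -100.0])  # made up: the second transition is a crash
dones = np.array([False, True])
actions = np.array([2, 0])         # actions taken in s

# Hypothetical Q-values of the primary network for s, shape (batch, n_actions)
q_s = np.array([[0.5, 0.1, 2.0, 0.3],
                [0.2, 0.4, 0.0, 0.1]])
# Hypothetical Q-values of the target network for s', shape (batch, n_actions)
q_next_target = np.array([[1.0, 3.0, 0.5, 0.2],
                          [9.0, 9.0, 9.0, 9.0]])

# y = R + gamma * max_a' Q_hat(s', a'), with the bootstrap term dropped when done
y = rewards + gamma * np.where(dones, 0.0, q_next_target.max(axis=1))
# Q(s, a; w) for the action actually taken in each transition
q_taken = q_s[np.arange(len(actions)), actions]
error = y - q_taken

print(y)                    # ≈ [3.985, -100.0]: 1 + 0.995 * 3.0, and the terminal reward only
print(error)                # ≈ [1.985, -100.2]
print(np.mean(error ** 2))  # the mean squared error is what gradient descent minimizes
```

Note how the large Q-values of the crashed transition's next state (all 9.0) are ignored entirely, because after a crash there is no future reward to bootstrap from.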

Finally, we update the target network's weights gently (in a 'soft' fashion). This means that instead of copying the primary network's weights, we replace the target network's weights by a weighted average of the primary network's weights and the target network's current weights:
 
$$
w^{target}\leftarrow \tau w + (1 - \tau) w^{target}
$$

where $\tau$ is normally close to 0. By using the soft update, we are ensuring that the target values, $y$, change slowly, which improves the stability of our learning algorithm.
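The soft update is an exponential moving average applied to every weight array. Below is a NumPy sketch over a dict of arrays, used as a stand-in for a Haiku params pytree (the layer names are hypothetical). It also shows how slowly the target follows with $\tau = 0.001$, the `TAU` value in this post: with the primary weights held fixed, the remaining gap shrinks as $(1 - \tau)^n$ after $n$ updates.

```python
import numpy as np


def soft_update(params, target_params, tau):
    """w_target <- tau * w + (1 - tau) * w_target, applied array by array."""
    return {name: tau * params[name] + (1.0 - tau) * target_params[name]
            for name in params}


# Hypothetical one-layer 'pytree': primary weights at 1.0, target weights at 0.0
toy_params = {"linear/w": np.ones((2, 2)), "linear/b": np.ones(2)}
toy_target_params = {"linear/w": np.zeros((2, 2)), "linear/b": np.zeros(2)}

tau = 0.001
for _ in range(1000):  # 1000 soft updates while the primary network stays fixed
    toy_target_params = soft_update(toy_params, toy_target_params, tau)

gap = 1.0 - toy_target_params["linear/b"][0]
print(f"gap after 1000 updates: {gap:.3f} (expected {(1 - tau) ** 1000:.3f})")  # both ≈ 0.368
```

In the notebook, `optax.incremental_update(new_tensors, old_tensors, step_size)` performs this same weighted average leaf by leaf on the Haiku params. Since updates only run every `UPDATE_PARAMS_EVERY_N_STEPS = 4` environment steps, closing about 63% of the gap takes on the order of 4,000 environment steps.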

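Let's implement it in code! Before writing the full `loss` function, let's explore the individual steps on the batch we sampled earlier: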

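```python
params = train_state.params
target_params = train_state.target_params

# Targets y: evaluate the next states with the target network, shape (32, 4)
next_state_action_values = target_network.apply(target_params, batch.next_states)
max_next_state_action_values = jnp.max(next_state_action_values, axis=1)  # shape (32,)
# No bootstrapping from the next state when the episode ended.
targets = batch.rewards + TrainConfig.GAMMA * jnp.where(batch.dones, 0.0, max_next_state_action_values)

# Predictions: Q-values of the primary network for the current states, shape (32, 4)
q_values = network.apply(params, batch.states)
```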



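Next, we pick the Q-value of the action that was actually taken in each sample. Indexing the rows with `jnp.arange(batch_size)` and the columns with `batch.actions` selects exactly one entry per row:

```python
q_value_for_action_taken = q_values[jnp.arange(q_values.shape[0]), batch.actions]
q_value_for_action_taken.shape
```

```
(32,)
```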

```python
print(q_values[0])
print(batch.actions[0])
print(q_value_for_action_taken[0])
assert q_value_for_action_taken[0] == q_values[0][batch.actions[0]]
```

```
[0.2518773  0.20313816 0.13113081 0.04511986]
1
0.20313816
```
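The expression `q_values[jnp.arange(q_values.shape[0]), batch.actions]` relies on integer-array (advanced) indexing: row `i` is paired with column `batch.actions[i]`. NumPy follows the same semantics, so a quick sketch with made-up numbers can show that it agrees with `np.take_along_axis` and with a one-hot mask (the `_example` names are just for illustration):

```python
import numpy as np

q_values_example = np.array([[0.25, 0.20, 0.13, 0.05],
                             [0.10, 0.30, 0.40, 0.00],
                             [0.90, 0.80, 0.70, 0.60]])
actions_example = np.array([1, 2, 0])

# Pairs (row 0, col 1), (row 1, col 2), (row 2, col 0)
picked = q_values_example[np.arange(q_values_example.shape[0]), actions_example]

# Two equivalent alternatives
picked_take = np.take_along_axis(q_values_example, actions_example[:, None], axis=1)[:, 0]
one_hot = np.eye(q_values_example.shape[1])[actions_example]
picked_mask = (q_values_example * one_hot).sum(axis=1)

print(picked)  # → [0.2 0.4 0.9]
```

A common mistake is writing `q_values[:, actions]` instead: that selects whole columns and returns a `(batch, batch)` matrix rather than one value per sample.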


Now let's put it all together in a `loss` function. Have a closer look at the inline comments, as they explain the meaning and shape of each array.

```python
def loss(params, target_params, batch):
    # q_values: expected future reward for each action in the current state,
    # e.g. shape = (32, 4)
    q_values = network.apply(params, batch.states)
    # q_values_pred: expected future reward for the action actually taken in each sample,
    # e.g. shape = (32,)
    q_values_pred = q_values[jnp.arange(q_values.shape[0]), batch.actions]

    # q_values_next: target network's expected future reward for each action in the next state,
    # e.g. shape = (32, 4)
    q_values_next = target_network.apply(target_params, batch.next_states)
    # q_values_next_max: expected future reward of the best action in the next state,
    # e.g. shape = (32,)
    q_values_next_max = jnp.max(q_values_next, axis=1)

    # build the target y (no bootstrapping for transitions that end the episode)
    q_value_true = batch.rewards + TrainConfig.GAMMA * jnp.where(batch.dones, 0.0, q_values_next_max)

    # mean squared error between predictions and targets
    return jnp.mean((q_values_pred - q_value_true) ** 2)
```

Finally, we have to update the parameters of both networks. The primary network is updated using gradient descent on the loss, while the target network (and the evaluation params) follow with a soft update, for which we use `optax.incremental_update()`.

```python
@jax.jit
def update(train_state: TrainingState, batch: Batch) -> TrainingState:
    """Learning rule (stochastic gradient descent) plus soft target update."""
    # Gradients are taken w.r.t. the primary params only; the target params act as constants.
    grads = jax.grad(loss)(train_state.params, train_state.target_params, batch)
    updates, opt_state = optimiser.update(grads, train_state.opt_state)
    params = optax.apply_updates(train_state.params, updates)

    # Soft update of the target network:
    # target_params = TAU * params + (1 - TAU) * target_params
    target_params = optax.incremental_update(params, train_state.target_params, TrainConfig.TAU)

    # Compute eval_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    eval_params = optax.incremental_update(
        params, train_state.eval_params, step_size=0.001)
    return TrainingState(params, target_params, eval_params, opt_state)
```

With that covered, the final thing we need to do is create the training loop. We'll use a few more tricks to make training work better: exploring with a decaying epsilon-greedy policy, and updating the params only every few steps. You can skip the details if you're not interested in them.

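Before that, a quick check: the network accepts both a single state of shape `(8,)` and a batch of states of shape `(N, 8)`. We rely on this when choosing actions one state at a time:

```python
state = env.reset()
batch_state = jnp.array([state, state])
network.apply(train_state.params, state)        # single state -> shape (4,)
network.apply(train_state.params, batch_state)  # batch of two states -> shape (2, 4)
```

```
DeviceArray([[ 0.27653083, -0.02404503,  0.06786434,  0.0538235 ],
             [ 0.27653083, -0.02404503,  0.06786434,  0.0538235 ]], dtype=float32)
```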

```python
def update_epsilon(epsilon, train_config: TrainConfig):
    """Decay epsilon multiplicatively, but never below E_MIN."""
    return max(train_config.E_MIN, train_config.E_DECAY * epsilon)


def exploit_or_explore(q_value: jnp.ndarray, epsilon: float = 0.1) -> int:
    """Exploit or explore according to an epsilon-greedy policy."""
    if random.random() < epsilon:
        return env.action_space.sample()  # explore: random action
    else:
        return int(jnp.argmax(q_value))  # exploit: action with the highest Q-value


def is_update_params(n_steps_taken: int, train_config: TrainConfig) -> bool:
    """Update params every `UPDATE_PARAMS_EVERY_N_STEPS` steps."""
    return (n_steps_taken + 1) % train_config.UPDATE_PARAMS_EVERY_N_STEPS == 0
```
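Because `update_epsilon` is applied once per episode, epsilon in episode `k` (counting from 0) is `max(E_MIN, E_DECAY ** k)`. A quick pure-Python check of what the values in `TrainConfig` imply; the constants are copied here so the snippet runs on its own, and `decay_epsilon` is a hypothetical standalone copy of the rule (named differently so it does not shadow `update_epsilon` in the notebook):

```python
# Values copied from TrainConfig above.
E_MIN = 0.01
E_DECAY = 0.995
N_EPISODES = 900


def decay_epsilon(eps):
    """Same rule as update_epsilon: multiplicative decay, floored at E_MIN."""
    return max(E_MIN, E_DECAY * eps)


# Epsilon used in each episode of the training loop (decayed after every episode).
eps = 1.0
schedule = []
for _ in range(N_EPISODES):
    schedule.append(eps)
    eps = decay_epsilon(eps)

print(f"epsilon in episode 1:   {schedule[0]:.3f}")
print(f"epsilon in episode 100: {schedule[99]:.3f}")
print(f"epsilon in episode 900: {schedule[-1]:.4f}")

# How many per-episode decays until epsilon bottoms out at E_MIN?
eps, steps = 1.0, 0
while eps > E_MIN:
    eps = decay_epsilon(eps)
    steps += 1
print(f"epsilon reaches E_MIN after {steps} episodes")  # → 919
```

With these values epsilon only reaches `E_MIN` after 919 episodes, so within `N_EPISODES = 900` it is still about 0.011 in the final episode. A smaller `E_DECAY` would shift more of the training towards exploitation; whether that helps is a tuning question.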


## What we (still) need

- [x] Environment
- [x] Memory buffer
- [x] DQN model
- [x] Loss function
- [ ] Training loop

Here is what will happen. We initialize some variables for keeping track of metrics and to create some cool visualizations later on.

Then we specify the number of episodes we let our agent train for. In each episode, our agent takes at most `train_config.MAX_N_STEPS_PER_EPISODE` steps; the episode ends earlier when the game is over, for example because the lander came to rest or crashed. Each experience is added to the agent's memory and, once the memory is full and `is_update_params()` says it is time, we update our parameters using gradient descent. We consider the environment solved, meaning we can safely land the moonlander, once we score an average total reward of at least 200 over the last 100 episodes. Let's see!
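The 'solved' criterion can be written as a small standalone helper. This is a sketch rather than code from the notebook: `is_solved` is a hypothetical name, and it deliberately returns `False` until a full window of 100 episodes is available, so that a lucky first few episodes cannot trigger it.

```python
from statistics import mean


def is_solved(total_reward_history, window=100, threshold=200.0):
    """True once the mean total reward over the last `window` episodes reaches `threshold`."""
    if len(total_reward_history) < window:
        return False  # not enough episodes yet for a full moving-average window
    return mean(total_reward_history[-window:]) >= threshold


history = [250.0] * 50      # hypothetical: a great start...
print(is_solved(history))   # → False: only 50 episodes so far
history += [170.0] * 50     # ...followed by a decent stretch
print(is_solved(history))   # → True: the mean of the last 100 is 210.0
history += [100.0] * 50     # ...and then a slump
print(is_solved(history))   # → False: the mean of the last 100 drops to 135.0
```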

```python
params_at_different_training_steps = {}
total_reward_history = []
moving_average_window_size = 100
epsilon = 1.0
train_config = TrainConfig()
for episode in range(train_config.N_EPISODES):
    state = env.reset()
    total_reward = 0.0

    for step in range(train_config.MAX_N_STEPS_PER_EPISODE):
        # Choose an action with the epsilon-greedy policy and play it.
        q_value = network.apply(train_state.params, state)
        action = exploit_or_explore(q_value=q_value, epsilon=epsilon)
        next_state, reward, is_done, *_ = env.step(action)
        memory.append(Experience(state, action, reward, next_state, is_done))

        # Only start learning once the memory is full, then every N steps.
        if len(memory) == TrainConfig.MEMORY_SIZE and is_update_params(step, train_config=train_config):
            batch = get_random_batch(memory)
            train_state = update(train_state, batch)

        state = next_state
        total_reward += reward
        if is_done:
            break

    total_reward_history.append(total_reward)
    mean_total_reward_in_window = np.mean(total_reward_history[-moving_average_window_size:])
    epsilon = update_epsilon(epsilon, train_config)

    progress = f"Episode {episode+1} | Total point average of the last {moving_average_window_size} episodes: {mean_total_reward_in_window:.2f}"
    # Overwrite the progress line in place, and keep a permanent line every window.
    print(f"\r{progress}", end="")
    if (episode+1) % moving_average_window_size == 0:
        print(f"\r{progress}")

    # Keep a snapshot of the params every 100 episodes for the videos below.
    if (episode+1) % 100 == 0:
        params_at_different_training_steps[f"params_episode_{episode + 1}"] = train_state.params

    # We consider the environment solved if we get an average of 200 points
    # over the last 100 episodes; store the params the first time this happens.
    window_is_full = len(total_reward_history) >= moving_average_window_size
    if (window_is_full and mean_total_reward_in_window >= 200.0
            and "params_episode_final" not in params_at_different_training_steps):
        print(f"\n\nEnvironment solved in {episode+1} episodes!")
        params_at_different_training_steps["params_episode_final"] = train_state.params
print("done training")
```

```
Episode 100 | Total point average of the last 100 episodes: -115.95
Episode 191 | Total point average of the last 100 episodes: -42.267

KeyboardInterrupt: 
```

Let's create some cool visuals to see how our agent improves over time.

```python
import imageio


def create_video(filename, env, train_state, fps=30, max_steps=300):
    """Record a greedy rollout of the agent, padded or truncated to `max_steps` steps."""
    steps = 0
    with imageio.get_writer(filename, fps=fps) as video:
        done = False
        state = env.reset()
        frame = env.render(mode="rgb_array")
        video.append_data(frame)
        while not done:
            steps += 1
            q_values = network.apply(train_state.params, state)
            action = int(jnp.argmax(q_values))  # always act greedily, no exploration
            state, _, done, *_ = env.step(action)
            frame = env.render(mode="rgb_array")
            video.append_data(frame)
            if done:
                # Repeat the last frame so that all videos have the same length.
                while steps < max_steps:
                    video.append_data(frame)
                    steps += 1
            if steps >= max_steps:
                break
```


```python
import os

os.makedirs('./videos', exist_ok=True)
for episode_name, episode_params in params_at_different_training_steps.items():
    filename = f"./videos/lunar_{episode_name}.gif"
    # Only `params` is used for acting; the other fields are placeholders.
    episode_train_state = TrainingState(episode_params, episode_params, episode_params, initial_opt_state)
    create_video(filename, env, episode_train_state)
```




| | | |
|:-------------------------:|:-------------------------:|:-------------------------:|
|<img width="1604" alt="Lunar lander after 100 training episodes" src="../assets/videos/lunar_params_episode_100.gif">  100 Episodes |  <img width="1604" alt="Lunar lander after 200 training episodes" src="../assets/videos/lunar_params_episode_200.gif"> 200 Episodes|<img width="1604" alt="Lunar lander after 300 training episodes" src="../assets/videos/lunar_params_episode_300.gif"> 300 Episodes |
|<img width="1604" alt="Lunar lander after 400 training episodes" src="../assets/videos/lunar_params_episode_400.gif">  400 Episodes |  <img width="1604" alt="Lunar lander after 500 training episodes" src="../assets/videos/lunar_params_episode_500.gif"> 500 Episodes|<img width="1604" alt="Lunar lander after 600 training episodes" src="../assets/videos/lunar_params_episode_600.gif"> 600 Episodes |
|<img width="1604" alt="Lunar lander after 700 training episodes" src="../assets/videos/lunar_params_episode_700.gif">  700 Episodes |  <img width="1604" alt="Lunar lander after 800 training episodes" src="../assets/videos/lunar_params_episode_800.gif"> 800 Episodes|<img width="1604" alt="Lunar lander after 900 training episodes" src="../assets/videos/lunar_params_episode_900.gif"> 900 Episodes |


## Conclusion

Well done if you made it this far! You just landed a man (uhhh) moonlander on the moon. This was the final blogpost in this series on getting started with JAX. In this blog we learnt a bit about reinforcement learning and how to use JAX to train an RL agent. Pretty awesome! Stay tuned for more advanced tutorials and on how to run your deep learning models in production!

Finally, if you liked this blogpost (or the series in general), I'd greatly appreciate it if you could like it on Medium and/or GitHub to help others find this content more easily as well.


## Contributors:
###### Submit a Pull Request or reach out on LinkedIn and become a recognized contributor :)