# Intro to Reinforcement Learning



<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/feat/intro-rl-section-1-updates/intro_to_rl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

© Deep Learning Indaba 2022. Apache License 2.0.

**Authors:**
Claude Formanek, Kale-ab Tessera

**Introduction:** 

In this tutorial, we will be learning about Reinforcement Learning (RL), a type of Machine Learning where an agent learns to choose actions in an environment that lead to maximal reward in the long run. RL has seen tremendous success on a wide range of challenging problems, such as learning to play complex video games like Atari games, StarCraft II and Dota 2. In this introductory tutorial, we will solve the classic CartPole environment, where an agent must learn to balance a pole on a cart, using several different RL approaches. Along the way, you will be introduced to some of the most important concepts and terminology in RL.

**Topics:** 
* Reinforcement Learning
* Policy Gradients
* Q-Learning

**Level:** 

Beginner

**Aims/Learning Objectives:**

* Understand the basic theory behind RL
* Implement a simple policy gradient RL algorithm
* Implement a simple Q-learning algorithm

**Prerequisites:**

* Some familiarity with [Jax](https://github.com/google/jax).
* Neural network basics.

**Outline:** 

* Section 1: Key Concepts in Reinforcement Learning
* Section 2: Random Policy Search
* Section 3: Policy Gradient
* Section 4: Deep Q-Learning

**Before you start:**

For this practical, you will need to use a GPU to speed up training. To do this, go to the "Runtime" menu in Colab, select "Change runtime type" and then in the popup menu, choose "GPU" in the "Hardware accelerator" box.


## Setup

In [None]:
# @title Install required packages (Run Cell)
%%capture
!pip install jaxlib
!pip install jax
!pip install git+https://github.com/deepmind/dm-haiku
!pip install "gym<0.26" # this notebook uses the pre-0.26 Gym API (env.seed, 4-value env.step)
!pip install "gym[box2d]<0.26"
!pip install rlax
!pip install optax
!pip install matplotlib

In [None]:
# @title Import required packages (Run Cell)
%%capture
import random
import collections # useful data structures
import numpy as np
import gym # reinforcement learning environments
import jax
import jax.numpy as jnp # jax numpy
import haiku as hk # jax neural network library
import optax # jax optimizer library
import rlax # jax reinforcement learning library
import matplotlib.pyplot as plt # graph plotting library

## Section 1: Key Concepts in Reinforcement Learning

Reinforcement Learning (RL) is a subfield of Machine Learning (ML). Unlike supervised learning, where we give our models examples of the expected behaviour, RL focuses on *goal-orientated* learning from interaction, through trial and error. RL algorithms learn what to do (i.e. which actions to take) in an environment to maximise some reward signal. In a video game, for example, the reward signal could be the game's score, so an RL algorithm would try to maximise the score by choosing the best actions.

<center>
<img src="https://miro.medium.com/max/1400/1*Ews7HaMiSn2l8r70eeIszQ.png" width="60%" />
</center>


More precisely, in RL we have an **agent** which perceives an **observation** $o_t$ of the current state $s_t$ of the **environment** and must choose an **action** $a_t$ to take. The environment then transitions to a new state $s_{t+1}$ in response to the agent's action and also gives the agent a scalar **reward** $r_t$ that indicates how good or bad the chosen action was in that state. The goal in RL is for the agent to maximise the total reward it receives from the environment over time. The subscript $t$ indicates the timestep, e.g. $s_0$ is the state of the environment at the initial timestep and $a_{99}$ is the agent's action at timestep $99$.

[Image Source](https://towardsdatascience.com/multi-agent-deep-reinforcement-learning-in-15-lines-of-code-using-pettingzoo-e0b963c0820b)

### Environment - OpenAI Gym
As mentioned above, an environment receives an action $a_t$ and returns a reward $r_t$ and the next observation $o_{t+1}$.

OpenAI has provided a Python package called Gym that includes implementations of popular environments and a simple interface for an RL agent to interact with. To use a supported gym [environment](https://www.gymlibrary.ml/), all you need to do is pass the name of the environment to the function `gym.make(<environment_name>)`. 

In this tutorial, we will be using a simple environment called CartPole. In **CartPole** the task is for the agent to learn to balance the pole for as long as possible by moving the cart *left* or *right*.

<img src="https://miro.medium.com/max/600/1*v8KcdjfVGf39yvTpXDTCGQ.gif" width="30%" />

In [None]:
# Create the environment
env = gym.make("CartPole-v1")

### States and Observations - $s_t$ and $o_t$

In RL, an agent perceives an observation of the environment's state. In some settings, the observation may include all the information underlying the environment's state. Such an environment is called **fully observed**. In other settings, the agent may only receive partial information about the environment's state in its observation. Such an environment is called **partially observed**. 
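To make the distinction concrete, here is a small sketch using a made-up CartPole state (the numbers and the helper names `full_observation` and `partial_observation` are hypothetical, not part of Gym). A fully observed agent sees all four state variables; a hypothetical partially observed variant that hides the velocities only tells the agent where the cart and pole are, not which way they are moving.

```python
import numpy as np

# A hypothetical CartPole state, in Gym's order:
# [cart position, cart velocity, pole angle, pole angular velocity]
state = np.array([0.01, 0.50, -0.03, -0.80], dtype=np.float32)


def full_observation(state):
    # Fully observed: the agent sees every state variable.
    return state.copy()


def partial_observation(state):
    # Partially observed (hypothetical variant): velocities are hidden, so a
    # single observation shows where the pole is but not where it is heading.
    return state[[0, 2]]


obs_full = full_observation(state)
obs_partial = partial_observation(state)
```

With only `obs_partial`, the agent would need to compare consecutive observations to estimate the hidden velocities.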

For the rest of this tutorial, we will assume the environment is fully observed and so we will use state $s_t$ and observation $o_t$ interchangeably. In Gym we get the initial observation from the environment by calling the function `env.reset()`: 

In [None]:
# Reset the environment
s_0 = env.reset()
print("Initial State::", s_0)

In CartPole, the state of the environment is represented by four numbers, in this order: the *position of the cart, velocity of the cart, angle of the pole and angular velocity of the pole*.
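As a sketch of how these four numbers are used, the function below checks whether a CartPole observation would end the episode. The thresholds (pole more than 12 degrees from vertical, cart more than 2.4 units from the centre) are assumed to match Gym's CartPole implementation for the version you have installed; `is_terminal` is a hypothetical helper, not a Gym function. CartPole-v1 additionally cuts episodes off after 500 timesteps via a time-limit wrapper, which this sketch does not model.

```python
import math

# Termination thresholds assumed to match Gym's CartPole implementation.
POLE_ANGLE_LIMIT = 12 * 2 * math.pi / 360  # 12 degrees in radians (about 0.2095)
CART_POSITION_LIMIT = 2.4


def is_terminal(obs):
    # obs = [cart position, cart velocity, pole angle, pole angular velocity]
    cart_position, _, pole_angle, _ = obs
    return bool(
        abs(cart_position) > CART_POSITION_LIMIT or abs(pole_angle) > POLE_ANGLE_LIMIT
    )


print(is_terminal([0.0, 0.0, 0.0, 0.0]))   # upright and centred
print(is_terminal([0.0, 0.0, 0.25, 0.0]))  # pole tipped past 12 degrees
```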

### Actions - $a_t$

In RL actions are usually either **discrete** or **continuous**. Continuous actions are given by a vector of real numbers. Discrete actions are given by an integer value. In environments where we can count out the finite set of actions we usually use discrete actions. 
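The two kinds of action look quite different as data. The sketch below samples one of each with NumPy: a discrete action is a single integer index (as in CartPole, where `0` pushes the cart left and `1` pushes it right), while a continuous action is a vector of real numbers. The 2-D continuous action in `[-1, 1]` is a made-up example, not taken from any particular environment.

```python
import numpy as np

rng = np.random.default_rng(0)

# Discrete: one integer out of a finite set, e.g. CartPole's {0: left, 1: right}.
num_discrete_actions = 2
discrete_action = int(rng.integers(num_discrete_actions))

# Continuous: a real-valued vector, e.g. a hypothetical 2-D command in [-1, 1].
continuous_action = rng.uniform(low=-1.0, high=1.0, size=2)

print("Discrete action:", discrete_action)
print("Continuous action:", continuous_action)
```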

In CartPole there are only two actions; *left and right*. As such, the actions can be represented by integers $0$ and $1$. In gym we can easily get the list of possible actions as follows:

In [None]:
# Get action space - e.g. discrete or continuous
print(f"Environment action space: {env.action_space}")

# Get num actions
num_actions = env.action_space.n
print(f"Number of actions: {num_actions}")

### The Agent's Policy - $\pi$

In RL the agent chooses actions based on the observations it receives. We can think of the agent's action selection process as a function that takes an observation as input and returns an action as output. In RL we usually call this function an agent's **policy** and denote it $\pi(s_t)$. 

Our policies can be **deterministic** ($s_t$ is mapped to a single action $a_t$), as follows: 

<center>
$\pi(s_t)=a_t$.
</center>

In other cases, our policy could rather be **stochastic**, where it returns a distribution and actions are sampled from this distribution. We denote stochastic policies as follows:
<center>
$a_t\sim\pi(\cdot\ |\ s_t)$
</center>

, where the symbol $\cdot$ is simply a shorthand for *all actions* and "~" means "*sampled from*".

Policies are often parameterized by weights $\theta$. We sometimes write these as a subscript on the policy symbol, $\pi_{\theta}$, to highlight the connection, but for convenience we will simply refer to policies as $\pi$.
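Here is a minimal sketch of the difference, assuming a hypothetical policy output `action_probs` for a two-action problem (in practice these probabilities would be computed from the observation using the weights $\theta$). Taking the most likely action is one way to get a deterministic policy $\pi(s_t)=a_t$; sampling from the probabilities gives a stochastic policy $a_t\sim\pi(\cdot\ |\ s_t)$.

```python
import numpy as np

# Hypothetical output of pi(. | s_t): P(a=0) = 0.2, P(a=1) = 0.8.
action_probs = np.array([0.2, 0.8])


def deterministic_policy(action_probs):
    # Always returns the single most likely action for this observation.
    return int(np.argmax(action_probs))


def stochastic_policy(rng, action_probs):
    # Samples an action according to the given probabilities.
    return int(rng.choice(len(action_probs), p=action_probs))


rng = np.random.default_rng(0)
deterministic_actions = [deterministic_policy(action_probs) for _ in range(100)]
stochastic_actions = [stochastic_policy(rng, action_probs) for _ in range(10_000)]

print("Deterministic actions seen:", set(deterministic_actions))
print("Fraction of stochastic 1s:", np.mean(stochastic_actions))  # close to 0.8
```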

**Exercise 1:** Bob has a deterministic policy $\pi$. The first time Bob uses his policy on some observation $o_t$, he chooses action $0$. What action will Bob choose if he uses the same policy a second time on the exact same observation $o_t$ (same timestep $t$)? Choose from the options below and assume there are only two possible actions:

In [None]:
#@title Exercise 1
selection = "Bob will chose action 1." #@param ["Bob will chose action 0 again.", "Bob will chose action 1.", "You can't say."]
print(f"You selected: {selection}")

In [None]:
#@title Check Exercise 1
correct_answer = "Bob will chose action 0 again."
assert selection == correct_answer, "Incorrect answer, hint ..."

""" 
Exercise 1: Since the policy is deterministic, Bob will always choose the same 
action given the same observation.
"""

print("Nice, you got the correct answer!")

**Exercise 2:** Alice has a stochastic policy $\pi$. The first time Alice uses her policy on some observation $o_t$, she chooses action $0$. What action will Alice choose if she uses her policy a second time on the exact same observation $o_t$? Choose from the options below and assume there are only two possible actions:

In [None]:
#@title Exercise 2
selection = "Alice will chose action 0 again." #@param ["Alice will chose action 0 again.", "Alice will chose action 1.", "You can't say."]
print(f"You selected: {selection}")

In [None]:
#@title Check Exercise 2
correct_answer = "You can't say."
assert selection == correct_answer, "Incorrect answer, hint ..."

""" 
Exercise 2: Since the policy is stochastic, the result from Alice's policy will be
random. So you can't say.
"""

print("Nice, you got the correct answer!")

As an exercise, we will implement a deterministic policy and then a stochastic policy for CartPole.

*Note, in this section we are not learning $\pi$'s weights $\theta$, we are just using a passed in value for the weights.*

**Exercise 3:** Complete the following functions:
- `linear_policy` : computes the linear combination of $\pi$'s weights $\theta$ and observation $o_t$. This is a linear function approximator. 
- `choose_action` : [Discretize](https://en.wikipedia.org/wiki/Discretization) the result from the linear policy as follows:
    - if the `result is less than or equal to zero` - return a `0`
    - if the `result is greater than zero` - return a `1`

Is this a deterministic or stochastic policy?

**Useful methods:** 
* [JAX Numpy dot product](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html).
* [JAX Numpy where](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) or [Jax select](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html#jax.lax.select).


**Notes:**
*   We already imported `jax numpy` as `jnp`.

In [None]:
# params refer to weights of our policy e.g. [1,-2,2,-1]
def linear_policy(params,obs):
  # YOUR CODE
  result = ...
  # END YOUR CODE
  return result 

# params refer to weights of our policy e.g. [1,-2,2,-1]
def choose_action(params, obs):
  result = linear_policy(params,obs)
  # YOUR CODE
  action = ...

  # END YOUR CODE
  return action

In [None]:
# @title Check correctness of implementation (Run me) 

def check_deterministic_linear_policy(linear_policy,choose_action):
  fixed_obs = jnp.array([1,1,2,4])
  
  # check case1 - negative dot product.
  # weights
  params = jnp.array([1,-2,2,-1])

  assert linear_policy(params,fixed_obs) == -1, "Incorrect answer, your linear policy is incorrect."
  assert choose_action(params,fixed_obs) == 0,  "Incorrect answer, your choose action function is incorrect."

  # Check deterministic policy
  for i in range(100):
    assert choose_action(params,fixed_obs) == 0,  "Incorrect answer, your choose action function is not deterministic."


  # check case2 - positive dot product
  params = jnp.array([1,2,2,1])

  assert linear_policy(params,fixed_obs) == 11, "Incorrect answer, your linear policy is incorrect."
  assert choose_action(params,fixed_obs) == 1,  "Incorrect answer, your choose action function is incorrect."

  # Check deterministic policy
  for i in range(100):
    assert choose_action(params,fixed_obs) == 1,  "Incorrect answer, your choose action function is not deterministic."

  # check case3 - 0 dot product
  params = jnp.array([0,0,0,0])

  assert linear_policy(params,fixed_obs) == 0, "Incorrect answer, your linear policy is incorrect."
  assert choose_action(params,fixed_obs) == 0,  "Incorrect answer, your choose action function is incorrect."

  # Check deterministic policy
  for i in range(100):
    assert choose_action(params,fixed_obs) == 0,  "Incorrect answer, your choose action function is not deterministic."

  print("Your function is correct!")
check_deterministic_linear_policy(linear_policy,choose_action)

In [None]:
# @title  Exercise 3 solution - Answer to code task (Try not to peek until you've given it a good try!')
# params refer to weights e.g. [1,-2,2,-1]
def linear_policy(params,obs):
  result = jnp.dot(params,obs)
  return result 

def choose_action(params, obs):
  result = linear_policy(params,obs)
  action = jnp.where(result <= 0,0,1)
  return action

check_deterministic_linear_policy(linear_policy,choose_action)

**Exercise 4:** Following on from Exercise 3, let's implement a different way of choosing actions. For this exercise, implement:
- `choose_action` : [Discretize](https://en.wikipedia.org/wiki/Discretization) the result from the linear policy as follows:
    - if the `result is less than or equal to zero` :  
      - return a `0` 20% of the time.
      - return a `1` 80% of the time.
    - if the `result is greater than zero` : 
      - return a `0` 100% of the time. 

Is this a deterministic or stochastic policy?

**Useful methods:** 
*   [JAX random choice](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.choice.html) or [NumPy uniform](https://numpy.org/doc/stable/reference/random/generated/numpy.random.uniform.html) (the JAX equivalent is `jax.random.uniform`).
* [JAX Numpy Where](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.where.html) or [JAX Conditional](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html).

In [None]:
# function from exercise 3
def linear_policy(params,obs):
  result = jnp.dot(params,obs)
  return result 

# key refers to our jax random key e.g. jax.random.PRNGKey(42)
def choose_action(key, params, obs):
  result = linear_policy(params,obs)
  # YOUR CODE
  action = ...

  # END YOUR CODE
  return action

In [None]:
# @title Check correctness of implementation (Run me) 

def check_stochastic_linear_policy(choose_action):
  key = jax.random.PRNGKey(42)
  fixed_obs = jnp.array([1,1,2,4])
  
  # check case1
  # weights
  params = jnp.array([1,-2,2,-1])

  assert linear_policy(params,fixed_obs) == -1, "Incorrect answer, your linear policy is incorrect."

  # Check stochastic policy
  actions = []
  for i in range(100):
    actions.append(choose_action(key,params,fixed_obs))
    new_key, _ = jax.random.split(key)
    key = new_key
  
  sum_actions = jnp.sum(jnp.array(actions))
  # This is around ~ 80%, we know the exact value since Jax handles PRNG consistently!
  assert sum_actions == 84,  "Incorrect answer, your choose action function is incorrect."

  # check case2
  params = jnp.array([1,2,2,1])

  assert linear_policy(params,fixed_obs) == 11, "Incorrect answer, your linear policy is incorrect."

  # Check stochastic policy
  actions = []
  for i in range(100):
    actions.append(choose_action(key,params,fixed_obs))
    new_key, _ = jax.random.split(key)
    key = new_key

  sum_actions = jnp.sum(jnp.array(actions))
  assert sum_actions == 0,  "Incorrect answer, your choose action function is incorrect."

check_stochastic_linear_policy(choose_action)

In [None]:
# @title  Exercise 4 solution - Answer to code task (Try not to peek until you've given it a good try!')
# function from exercise 3
def linear_policy(params,obs):
  result = jnp.dot(params,obs)
  return result 

# key refers to our jax random key e.g. jax.random.PRNGKey(42)
def choose_action(key, params, obs):
  result = linear_policy(params,obs)
  action = jnp.where(result <= 0,jax.random.choice(key,a=jnp.array([0,1]),p=jnp.array([0.2,0.8])),0)
  return action

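# Alternative using Python if statements and jax.random.uniform. This works when run eagerly,
# but it cannot be JIT-compiled because it branches on the value of a JAX array.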
# def choose_action(key, params, obs):
#   result = linear_policy(params,obs)
#   r = jax.random.uniform(key)
#   if result <= 0:
#     if r <= 0.8:
#       action = 1
#     else:  
#       action = 0 
#   else:
#     action = 0
#   return action

check_stochastic_linear_policy(choose_action)

### The Environment Transition Function - $P$

Now that we have a policy we can pass actions from the agent to the environment. The environment will then transition into a new state in response to the agent's action. 

In RL we model this process by using a **state transition function** $P$ which takes the current state $s_t$ and an action $a_t$ as input and returns the next state $s_{t+1}$ as output. Like with policies, the state transition function can either be *deterministic*: 
<center>
 $s_{t+1}=P(s_t, a_t)$
</center> 

or it can be *stochastic*: 
<center>
 $s_{t+1}\sim P(\cdot\ |\ s_t, a_t)$
</center>
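To see the two kinds of transition function side by side, here is a toy sketch (not CartPole) of a hypothetical 1-D corridor with positions `0` to `4`, where action `0` moves left and action `1` moves right. The deterministic version always moves as instructed; the stochastic version "slips" and moves the wrong way with probability `slip_prob`, an illustrative parameter chosen for this example.

```python
import numpy as np

NUM_POSITIONS = 5  # hypothetical corridor with positions 0..4


def deterministic_transition(state, action):
    """s_{t+1} = P(s_t, a_t): action 0 moves left, action 1 moves right."""
    step = 1 if action == 1 else -1
    return int(np.clip(state + step, 0, NUM_POSITIONS - 1))  # walls at both ends


def stochastic_transition(rng, state, action, slip_prob=0.2):
    """s_{t+1} ~ P(. | s_t, a_t): with probability slip_prob the move goes the wrong way."""
    if rng.random() < slip_prob:
        action = 1 - action
    return deterministic_transition(state, action)


rng = np.random.default_rng(0)
next_states = [stochastic_transition(rng, 2, 1) for _ in range(10)]
print("Deterministic next state:", deterministic_transition(2, 1))
print("Stochastic next states:", next_states)  # mostly 3, occasionally 1
```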





In gym we can pass actions to the environment by calling the `env.step(<action>)` function. The function will then return four values:
- the **next observation**
- the **reward** for the action taken
- a boolean flag to indicate if the game is **done** 
- some **extra** information.

In [None]:
# Get the initial obs by resetting the env
initial_obs = env.reset()

# Randomly sample an action from the environment's action space
action = env.action_space.sample()

# Step the environment
next_obs, reward, done, info = env.step(action)

print("Observation:", initial_obs)
print("Action:", action)
print("Next observation:", next_obs)
print("Reward:", reward)
print("Game is done:", done)

### Episode Return - $R_t$

In RL we usually break an agent's interactions with the environment up into **episodes** (sequence of interactions with an environment, usually ending in a terminal state). 

The sum of all rewards collected during an episode is what we call the episode's **return**. The simplest formulation is the **finite-horizon, undiscounted return** from timestep $t$, which can be written as follows:

<center>
$R_t=\sum_{k=t}^{T} r_k$
</center>

, where $r_k$ is the reward at timestep $k$ and $T$ is the final timestep of the episode, so the return of the whole episode is $R_0$. This return is calculated over a finite window of time, assumes that the episode always ends by timestep $T$, and weights every reward $r_k$ equally.
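Computing the undiscounted return of an episode is just a sum over its rewards. The sketch below uses a made-up 8-step CartPole-style episode in which every step gives a reward of `1`, so the return equals the number of steps the pole stayed up; `undiscounted_return` is an illustrative name.

```python
def undiscounted_return(rewards):
    # Every reward counts equally, regardless of when it was received.
    return float(sum(rewards))


# Hypothetical 8-step CartPole episode: +1 reward on every step.
cartpole_rewards = [1.0] * 8
print("Episode return:", undiscounted_return(cartpole_rewards))  # 8.0
```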


In practice, we often use the **infinite-horizon, discounted return** instead, which sums all rewards obtained from timestep $t$ onwards, discounting each reward by how far in the future it is received:

<center>
$R_t=\sum_{k=0}^{\infty} \gamma^{k} r_{t+k}$
</center>

, where $\gamma \in [0,1)$ is the discount factor.

The goal in RL is for the agent to choose actions that maximise the *expected* return $\mathbb{E}[R_t]$.


**Group task**: Discuss with your neighbour: why would we want to use a discount factor?

### Agent-environment Loop
Now that we know what a policy is and we know how to step the environment, let's close the agent-environment loop.

**Exercise 5:** Write a function that runs one episode of CartPole by sequentially choosing actions and stepping the environment. You should use the stochastic policy we defined earlier (i.e. the `choose_action` function from Exercise 4) to choose actions. The function should keep track of the rewards received and return the episode return at the end of the episode. For simplicity, we will use the **finite-horizon, undiscounted return**.

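In CartPole the agent receives a reward of `1` for every timestep, including the final one. The episode ends when the pole tips more than 12 degrees from vertical or the cart moves more than 2.4 units from the centre (CartPole-v1 also cuts episodes off after 500 timesteps).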

In [None]:
def run_episode(env):
  episode_return = 0

  ## YOUR CODE
  # initial obs
  obs = ...
  done = False
  # policy params
  params = jnp.array([1,-2,2,-1])
  key = jax.random.PRNGKey(42)

  # while loop until episode is done
  while not done:
    # HINT: You might need to convert the action from your policy to a np.array
    action = ...
    # HINT: Step in your environment
    next_obs, reward, done, info = ...
    # HINT: Update observations
    obs = ...
    episode_return = ...

  return episode_return

In [None]:
# @title Check correctness of implementation (Run me) 

env.seed(42)
assert run_episode(env) == 8

In [None]:
#@title Exercise 5 solution
def run_episode(env):
  episode_return = 0
  obs = env.reset()
  done = False
  params = jnp.array([1,-2,2,-1])
  key = jax.random.PRNGKey(42)

  while not done:
    action = np.array(choose_action(key, params, obs))
    next_obs, reward, done, info = env.step(action)
    obs = next_obs
    episode_return += reward

  return episode_return

env.seed(42)
assert run_episode(env) == 8

print("Episode return:", run_episode(env))

In CartPole-v1, episodes are cut off after 500 timesteps, so the highest possible episode return is 500, and Gym registers a reward threshold of 475 for considering the environment solved. As you can see, our current policy is nowhere near optimal yet. Let's look at a way to find a better policy.
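A common convention (assumed here, not something Gym checks for you) is to call CartPole-v1 solved when the average return over the last 100 consecutive episodes reaches 475. The hypothetical helper `is_solved` below sketches that check on a list of episode returns.

```python
import numpy as np


def is_solved(episode_returns, threshold=475.0, window=100):
    """Sketch of a common CartPole-v1 'solved' criterion:
    the mean return over the last `window` episodes is at least `threshold`."""
    if len(episode_returns) < window:
        return False  # not enough episodes yet to judge
    return float(np.mean(episode_returns[-window:])) >= threshold


print(is_solved([8.0] * 100))                  # far from solved
print(is_solved([8.0] * 50 + [500.0] * 100))   # last 100 episodes all hit the cap
```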

One way we can find an optimal policy is by randomly searching for it. Obviously in a complex environment finding an optimal policy by randomly trying different strategies could take forever. But CartPole is a sufficiently simple environment that it might just work.

Before we implement Random Policy Search, let's take a look at the environment loop function that we have implemented for you; it's the environment loop we will use for the rest of the notebook. First, we will use a [NamedTuple](https://www.geeksforgeeks.org/namedtuple-in-python/) to bundle `obs`, `action`, `reward`, `next_obs` and the `done` flag into a **transition** object.

In [None]:
# Named tuple to store transition
Transition = collections.namedtuple("Transition", ["obs", "action", "reward", "next_obs", "done"])

# TEST
transition = Transition(
    obs=[1,2,-1,2],
    action=0,
    reward=10,
    next_obs=[1,2,2,1],
    done=True
)

print("Transition obs:", transition.obs)
print("Transition action:", transition.action)
print("Transition reward:", transition.reward)
print("Transition next obs:", transition.next_obs)
print("Transition done:", transition.done)

Below is the environment loop function we have implemented for you. We recommend reading through it and trying to understand it, but if anything is unclear, don't worry about it. It should all make more sense as we work through this tutorial.

In [None]:
# Environment loop
def run_environment_loop(rng, env, agent_params, agent_select_action_func, 
    agent_actor_state=None, agent_learn_func=None, agent_learner_state=None, 
    agent_memory=None, num_episodes=1000, evaluator_period=100, 
    evaluation_episodes=32, learn_steps_per_episode=1):
    """
    This function runs several episodes in an environment and periodically does 
    some agent learning and evaluation.
    
    Args:
        rng: a random number generator. This is for jax.
        env: a gym environment.
        agent_params: an object to store parameters that the agent uses.
        agent_select_func: a function that does action selection for the agent.
        agent_actor_state (optional): an object that stores the internal state 
            of the agents action selection function.
        agent_learn_func (optional): a function that does some learning for the 
            agent by updating the agent parameters.
        agent_learn_state (optional): an object that stores the internal state 
            of the agent learn function.
        agent_memory (optional): an object for storing an retrieving historical 
            experience.
        num_episodes: how many episodes to run.
        evaluator_period: how often to run evaluation.
        evaluation_episodes: how many evaluation episodes to run.

    Returns:
        episode_returns: list of all the episode returns.
        evaluator_episode_returns: list of all the evaluator episode returns.
    """
    episode_returns = [] # List to store history of episode returns.
    evaluator_episode_returns = [] # List to store history of evaluator returns.
    for episode in range(num_episodes):

        # Reset environment.
        obs = env.reset()
        episode_return = 0
        done = False

        while not done:

            # Agent select action.
            action, agent_actor_state = agent_select_action_func(
                                            next(rng), 
                                            agent_params, 
                                            agent_actor_state, 
                                            np.array(obs)
                                        )

            # Step environment.
            next_obs, reward, done, _ = env.step(int(action))

            # Pack into transition.
            transition = Transition(obs, action, reward, next_obs, done)

            # Add transition to memory.
            if agent_memory is not None: # check if agent has memory
                agent_memory.push(transition)

            # Add reward to episode return.
            episode_return += reward

            # Set obs to next obs before next environment step. CRITICAL!!!
            obs = next_obs

        episode_returns.append(episode_return)

        # At the end of every episode we do a learn step.
        if agent_memory and agent_memory.is_ready(): # Make sure memory is ready

            for _ in range(learn_steps_per_episode):
                # First sample memory and then pass the result to the learn function
                memory = agent_memory.sample()
                agent_params, agent_learner_state = agent_learn_func(
                                                        next(rng), 
                                                        agent_params, 
                                                        agent_learner_state, 
                                                        memory
                                                    )

        if (episode % evaluator_period) == 0: # Do evaluation

            evaluator_episode_return = 0
            for eval_episode in range(evaluation_episodes):
                obs = env.reset()
                done = False
                while not done:
                    action, _ = agent_select_action_func(
                                    next(rng), 
                                    agent_params, 
                                    agent_actor_state, 
                                    np.array(obs), 
                                    evaluation=True
                                )

                    obs, reward, done, _ = env.step(int(action))

                    evaluator_episode_return += reward

            evaluator_episode_return /= evaluation_episodes

            evaluator_episode_returns.append(evaluator_episode_return)

            logs = [
                    f"Episode: {episode}",
                    f"Episode Return: {episode_return}",
                    f"Average Episode Return: {np.mean(episode_returns[-20:])}",
                    f"Evaluator Episode Return: {evaluator_episode_return}"
            ]

            print(*logs, sep="\t") # Print the logs

    return episode_returns, evaluator_episode_returns

Before we can test this environment loop with the policy function we implemented earlier, we need to modify the policy function so that its interface matches what the environment loop expects. The `choose_action` function should take a random key as its first argument, followed by the parameters, the actor's internal state (more on this later), the observation and, finally, an `evaluation` boolean flag that indicates whether the function is being called during evaluation (more on this later). The function should return the chosen action and the actor's next internal state.

Let's adapt our solution from **exercise 3**. 

In [None]:
def choose_action(key, params, actor_state, obs, evaluation=False):
  del key, evaluation # not used yet

  result = linear_policy(params,obs)
  action = jnp.where(result <= 0,0,1)
  return action, actor_state

Let's [JIT](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) our `choose_action` function to speed it up. If you are unfamiliar with JIT or other JAX basics, you can review the *Introduction to ML using Jax* practical.



In [None]:
choose_action_jit = jax.jit(choose_action)

**A quick aside about random numbers:** when using JAX you will sometimes need to pass a random key to your functions. We will use Haiku's `hk.PRNGSequence` to produce a fresh random key every time we call `next(rng)`.
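Why do we need a fresh key each time? In JAX, a key fully determines the random numbers drawn from it, so reusing a key reproduces exactly the same draw. The sketch below shows this with `jax.random.uniform` and shows `jax.random.split`, the standard way to derive new keys; `hk.PRNGSequence` takes care of this splitting for us.

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same "random" number.
a = jax.random.uniform(key)
b = jax.random.uniform(key)

# Splitting derives new keys, so the next draw uses fresh randomness.
key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey)

print("Same key twice:", float(a), float(b))
print("After splitting:", float(c))
```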

In [None]:
# Jax random number generator
rng = hk.PRNGSequence(jax.random.PRNGKey(0))

We can now test our `choose_action_jit` function in the `run_environment_loop` function.

In [None]:
# Some arbitrary parameters
params = np.array([1,1,-1,-1], "float32")

episode_returns, evaluator_returns = run_environment_loop(
                                        rng, # random number generator
                                        env, # environment
                                        params, # parameters used by the action selector
                                        choose_action_jit,
                                        num_episodes=1001,
                                      )

print("Average episode returns:", np.mean(episode_returns))

### Agent memory

In many RL algorithms, the agent uses a kind of memory to store some of the experiences it had in the environment. The interface we will use for the agent's memory is very simple. It will have a function `memory.push(<transition>)` that adds some information about the transition to the memory, a function `memory.is_ready()` to check if the memory is ready to do some learning, and finally a function `memory.sample()` that returns some information that the agent learn function can use to do learning.
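The interface above can be written down explicitly with Python's `typing.Protocol`, which documents the three methods the environment loop relies on. The sketch below also includes a hypothetical `RecentTransitionsMemory` that keeps only the most recent transitions (using strings as stand-in transitions); it is an illustration of the interface, not one of the memories used later in this notebook.

```python
from collections import deque
from typing import Any, List, Protocol


class AgentMemory(Protocol):
    """The three methods the environment loop relies on."""

    def push(self, transition: Any) -> None: ...

    def is_ready(self) -> bool: ...

    def sample(self) -> Any: ...


class RecentTransitionsMemory:
    """Hypothetical example memory: keeps only the `capacity` most recent transitions."""

    def __init__(self, capacity: int = 3):
        self.capacity = capacity
        self.buffer: deque = deque(maxlen=capacity)  # oldest items are dropped automatically

    def push(self, transition: Any) -> None:
        self.buffer.append(transition)

    def is_ready(self) -> bool:
        return len(self.buffer) == self.capacity

    def sample(self) -> List[Any]:
        return list(self.buffer)


def maybe_sample(memory: AgentMemory):
    # Mirrors how the loop uses a memory: only sample once it reports it is ready.
    return memory.sample() if memory.is_ready() else None


memory = RecentTransitionsMemory(capacity=3)
memory.push("t0")
memory.push("t1")
first = maybe_sample(memory)   # None: only 2 of 3 slots filled
memory.push("t2")
memory.push("t3")
second = maybe_sample(memory)  # ["t1", "t2", "t3"]
```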

#### Average Episode Return Memory
We have built a simple agent memory module for you below. It accumulates the returns of completed episodes until it has stored 20 of them, at which point it is ready. The `memory.sample()` method returns the average of those 20 episode returns and then clears the buffer. Read through our implementation below and see if you can understand it.

In [None]:
# A NamedTuple to store the average episode return of the last 20 runs
AverageEpisodeReturnMemory = collections.namedtuple(
                                "AverageEpisodeReturnMemory", 
                                ["average_episode_return"]
                            )

class AverageEpisodeReturnBuffer:

    def __init__(self, num_episodes_to_store=20):
        """
        This class implements an agent memory that stores the average episode 
        return over the last 20 episodes.
        """
        self.num_episodes_to_store = num_episodes_to_store
        self.episode_return_buffer = []
        self.current_episode_return = 0

    def push(self, transition):
        self.current_episode_return += transition.reward

        if transition.done: # If the episode is done
            # Add episode return to buffer
            self.episode_return_buffer.append(self.current_episode_return)

            # Reset episode return
            self.current_episode_return = 0


    def is_ready(self):
        return len(self.episode_return_buffer) == self.num_episodes_to_store

    def sample(self):
        average_episode_return = np.mean(self.episode_return_buffer)

        # Clear episode return buffer
        self.episode_return_buffer = []

        return AverageEpisodeReturnMemory(average_episode_return)

## Section 2: Random Policy Search (RPS)
In Section 1, we used predefined parameters for our policy. That is to say, we didn't learn $\pi$'s weights $\theta$; we simply set them (e.g. `params = jnp.array([1,-2,2,-1])`). There are various algorithms for finding and improving our policy's weights.

One such algorithm is Random Policy Search (RPS), which is a method that will randomly try different policies and keep track of the best policy it has found so far. We will say that policy $A$ is better than policy $B$ if the average episode return policy $A$ achieved over the last 20 episodes is greater than that of policy $B$. We will need to modify the way we store the agent's parameters so that we always have access to the latest parameters as well as the best parameters.
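The core idea of random search can be sketched without CartPole at all. Below, a hypothetical `toy_score` function stands in for "average episode return over the last 20 episodes" (its `target` vector is made up), and `random_policy_search` repeatedly draws fresh random parameters and keeps whichever set has scored best so far. This is a simplified illustration of the idea, not the exact algorithm we will implement in this section.

```python
import numpy as np


def toy_score(params):
    # Hypothetical stand-in for the average episode return of a policy.
    target = np.array([0.5, -1.0, 2.0, 0.0])
    return -float(np.sum((params - target) ** 2))


def random_policy_search(score_fn, num_params=4, num_iterations=200, seed=0):
    rng = np.random.default_rng(seed)
    best_params = rng.uniform(-3, 3, size=num_params)
    best_score = score_fn(best_params)
    for _ in range(num_iterations):
        candidate = rng.uniform(-3, 3, size=num_params)  # try a fresh random policy
        score = score_fn(candidate)
        if score > best_score:  # keep it only if it beats the best so far
            best_params, best_score = candidate, score
    return best_params, best_score


initial_params, initial_score = random_policy_search(toy_score, num_iterations=0)
best_params, best_score = random_policy_search(toy_score, num_iterations=200)
print("Initial score:", initial_score, "Best score:", best_score)
```

Because both calls use the same seed, they start from the same initial parameters, so the best score can only stay the same or improve as more random policies are tried.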

In [None]:
# Parameter container for Random Policy Search
RandomPolicySearchParams = collections.namedtuple("RandomPolicySearchParams", ["current", "best"])

# TEST: store two different sets of parameters
current_params = np.array([1,1,-1,-1])
best_params = np.array([0,0,0,0])
rps_params = RandomPolicySearchParams(current_params, best_params)

# How to access the best or current params.
print(f"Best params: {rps_params.best}")
print(f"Current params: {rps_params.current}")

We will implement the following:
  - RPS Select Action - How we choose actions given a policy.
  - RPS Learn - How we update and improve our policy.

### RPS Select Action Function
Now let's once again modify our `choose_action` function such that it uses the best parameters when `evaluation==True` and uses the current parameters when `evaluation==False`.

> We can still use the `linear_policy` function to calculate the forward pass of our policy, and then discretize its result as follows:
>   - if the result is less than or equal to zero, return `0`
>   - if the result is greater than zero, return `1`

**Exercise 7:** Implement the `random_policy_search_choose_action` function as described above. Make sure you use Jax so that we can jit the function.

In [None]:
def linear_policy(params,obs):
  result = jnp.dot(params,obs)
  return result 

def random_policy_search_choose_action(
    key, 
    params, 
    actor_state, 
    obs, 
    evaluation=False
):

  # YOUR CODE

  # HINT: best_action = ... (two lines)
  # 2 steps - (1) Forward pass through linear policy, (2) then Discretize

  # HINT: current_action = ... (two lines)
  # 2 steps - (1) Forward pass through linear policy, (2) then Discretize

  # HINT: action = jnp.where(evaluation, best_action, current_action) (one line)
  # (a Python `if` on `evaluation` fails under jax.jit, because `evaluation` is traced)

  # END YOUR CODE

  return action, actor_state

In [None]:
# @title Check correctness of implementation (Run me) 

def check_random_policy_search_choose_action(choose_action):
  key = None # not used
  actor_state = None # not used

  # obs
  obs = jnp.array([1,1,2,4])

  # eval=False checks
  # Parameters
  evaluation=False
  current_params = jnp.array([-1,-1,-1,-1])
  best_params = jnp.array([0,0,0,0])
  # check case1 - negative dot product.
  current_params = jnp.array([1,-2,2,-1])
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  assert action == 0,  "Incorrect answer, your choose action function is incorrect."

  # check case2 - positive dot product
  current_params = jnp.array([1,2,2,1])
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  assert action == 1,  "Incorrect answer, your choose action function is incorrect."

  # check case3 - 0 dot product
  current_params = jnp.array([0,0,0,0])
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  assert action == 0,  "Incorrect answer, your choose action function is incorrect."

  # eval=True checks
  evaluation=True
  current_params = jnp.array([-1,-1,-1,-1])
  best_params = jnp.array([0,0,0,0])
  # check case1 - negative dot product.
  best_params = jnp.array([1,-2,2,-1])
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  assert action == 0,  "Incorrect answer, your choose action function is incorrect."

  # check case2 - positive dot product
  best_params = jnp.array([1,2,2,1])
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  assert action == 1,  "Incorrect answer, your choose action function is incorrect."

  # check case3 - 0 dot product
  best_params = jnp.array([0,0,0,0])
  rps_params = RandomPolicySearchParams(current_params, best_params)
  action, actor_state = choose_action(key,rps_params,actor_state,obs,evaluation)
  assert action == 0,  "Incorrect answer, your choose action function is incorrect."

  print("Your function is correct!")

random_policy_search_choose_action_jit = jax.jit(random_policy_search_choose_action) # jit the function
check_random_policy_search_choose_action(random_policy_search_choose_action_jit)

In [None]:
#@title Exercise 7 solution
def linear_policy(params,obs):
  result = jnp.dot(params,obs)
  return result 

def random_policy_search_choose_action(
    key, 
    params, 
    actor_state, 
    obs, 
    evaluation=False
):

  dot_product = linear_policy(params.best,obs)
  best_action = jnp.where(dot_product <= 0,0,1)

  dot_product = linear_policy(params.current,obs)
  current_action = jnp.where(dot_product <= 0,0,1)

  action = jnp.where(evaluation,best_action,current_action)
  return action, actor_state

random_policy_search_choose_action_jit = jax.jit(random_policy_search_choose_action) # jit the function
check_random_policy_search_choose_action(random_policy_search_choose_action_jit)

### RPS learn function
Now we need to implement a `learn` function for our Random Policy Search agent. The `learn` function is quite simple: all we need to do is check whether the current weights are better than the best weights. If they are, set the current weights as the new best weights. In either case, we then randomly sample a new set of current weights. 

Let's assume that our learn function receives a memory (an `AverageEpisodeReturnMemory`) sampled from the `AverageEpisodeReturnBuffer` we implemented earlier. We can use it to compare the current weights to the best weights. We will also need to keep track of the best average episode return seen so far; for that, we can use the `learner_state` argument.

In [None]:
# A NamedTuple to store the best average episode return so far
LearnerState = collections.namedtuple(
                                      "LearnerState", 
                                      ["count", "best_average_episode_return"]
                                      )

**Exercise 8:** Write a function that randomly samples new weights using jax. The weights should be sampled uniformly from the interval `[-2,2]`.

**Useful functions:** 
*   [Jax random uniform sample](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.uniform.html#jax.random.uniform)

In [None]:
def get_new_random_weights(random_key, old_weights,minval=-2.0,maxval=2.0):
    new_weights_shape = old_weights.shape
    new_weights_dtype = old_weights.dtype

    # YOUR CODE

    # you should overwrite this
    new_weights = ...

    # END YOUR CODE
    return new_weights

In [None]:
# @title Check correctness of implementation (Run me) 

def check_get_new_random_weights(get_new_random_weights):
  old_weights = jnp.array([1,1,1,1], "float32")
  random_key = jax.random.PRNGKey(42)

  # Case 1
  new_weights = get_new_random_weights(random_key, old_weights,minval=-2.0,maxval=2.0)
  assert jnp.allclose(new_weights,jnp.array([ 0.29657745,1.4265499, -1.7621555, -1.7505779 ]))
  
  # Case 2
  new_weights = get_new_random_weights(random_key, old_weights,minval=-0.1,maxval=0.1)
  assert jnp.allclose(new_weights,jnp.array([0.01482888,0.0713275,-0.08810778,-0.0875289]))
  
  print("Function is correct!")

get_new_random_weights_jit = jax.jit(get_new_random_weights) # jit the function
check_get_new_random_weights(get_new_random_weights_jit)

In [None]:
#@title Exercise 8 solution

def get_new_random_weights(random_key, old_weights,minval=-2.0,maxval=2.0):
    new_weights_shape = old_weights.shape
    new_weights_dtype = old_weights.dtype
    # Sample new weights
    new_weights = jax.random.uniform(random_key,new_weights_shape,new_weights_dtype,minval=minval,
                      maxval=maxval)
    return new_weights

get_new_random_weights_jit = jax.jit(get_new_random_weights) # jit the function
check_get_new_random_weights(get_new_random_weights_jit)

Now let's implement the Random Policy Search learn function.

**Exercise 9:** Use the description of the Random Policy Search learn function at the top of this section to complete the function below. Try to use jax (remember `jnp.where()`).

In [None]:
def random_policy_search_learn(key, params, learner_state, memory):
    best_weights = params.best 
    current_weights = params.current

    current_episode_return = memory.average_episode_return
    best_average_episode_return = learner_state.best_average_episode_return


    # YOUR CODE

    # HINT: if current better than best then ...
    best_weights = ...
        
    best_average_episode_return = ...
    
    # END YOUR CODE

    # Generate new random weights
    new_weights = get_new_random_weights_jit(key, best_weights)

    # Bundle weights in RPS Params NamedTuple
    params = RandomPolicySearchParams(current=new_weights, best=best_weights)

    # Increment the learn counter by one
    learn_count = learner_state.count + 1

    return params, LearnerState(learn_count, best_average_episode_return)

In [None]:
#@title Exercise 9 solution
def random_policy_search_learn(key, params, learner_state, memory):
    best_weights = params.best 
    current_weights = params.current

    current_episode_return = memory.average_episode_return
    best_average_episode_return = learner_state.best_average_episode_return

    # Update best_weights and best_average_episode_return
    best_weights = jnp.where(current_episode_return>best_average_episode_return,current_weights,best_weights)    
    best_average_episode_return = jnp.where(current_episode_return>best_average_episode_return,current_episode_return,best_average_episode_return)

    # Generate new random weights
    new_weights = get_new_random_weights_jit(key, best_weights)

    # Bundle weights in RPS Params NamedTuple
    params = RandomPolicySearchParams(current=new_weights, best=best_weights)

    # Increment the learn counter by one
    learn_count = learner_state.count + 1

    return params, LearnerState(learn_count, best_average_episode_return)

Now we can put everything together using the environment loop.

In [None]:
# Jax random number generator
initial_seed=0
rng = hk.PRNGSequence(jax.random.PRNGKey(initial_seed))

initial_learner_state = LearnerState(0, -float("inf"))

# Jit the learn function for some extra speed
random_policy_search_learn_jit = jax.jit(random_policy_search_learn)

initial_weights = np.array([1,1,1,1], "float32")
initial_params = RandomPolicySearchParams(initial_weights, initial_weights)

memory = AverageEpisodeReturnBuffer(num_episodes_to_store=20)

episode_return, evaluator_episode_returns = run_environment_loop(
                                        rng, 
                                        env, 
                                        initial_params, 
                                        random_policy_search_choose_action_jit, 
                                        None, # no actor state
                                        random_policy_search_learn_jit, 
                                        initial_learner_state, 
                                        memory, 
                                        num_episodes=2001
                                    )

# Plot graph of evaluator episode returns
plt.plot(evaluator_episode_returns)
plt.title("Random Search Evaluator Episode Return")
plt.show()

Hopefully, you found a set of optimal parameters for CartPole (the evaluator episode return eventually reaches `500`). If you haven't, try running the environment loop again with a different initial seed; you were probably just a little unlucky. That is the big limitation of Random Policy Search: if you are unlucky, you might never (randomly) stumble on the optimal policy.

So, in Random Policy Search there is very little (if any) real learning going on. We may have been able to find a reasonable policy in this case, but would this work if our policy had many more weights (imagine 100s or 1000s of weights, instead of 4)?   
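A quick back-of-the-envelope calculation shows why dimensionality hurts. Assume (hypothetically) that the "good" policies occupy a box of side length 1 somewhere inside the search box $[-2,2]^d$. A single uniform sample then lands in the good box with probability $(1/4)^d$, so on average we need about $4^d$ samples. The function `hit_probability` below is illustrative only.

```python
def hit_probability(num_weights, good_width=1.0, search_width=4.0):
    """Chance that one uniform sample from [-2, 2]^d lands in a hypothetical
    'good' box of side good_width (assumed to lie fully inside the search box)."""
    return (good_width / search_width) ** num_weights


for d in [4, 10, 100]:
    p = hit_probability(d)
    print(f"d={d:>3}: p={p:.3e}, expected samples ~ {1 / p:.3e}")
```

With 4 weights this is about 1 sample in 256, which is manageable; with 100 weights the expected number of samples is astronomically large.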

Next, let's implement a simple RL algorithm that uses its experience to guide the learning process, rather than just randomly sampling new weights. Hopefully, this will help us find an optimal policy more reliably.

## Section 3: Policy Gradients (PG)
As discussed, the goal in RL is to find a policy that maximises the expected cumulative reward (return) the agent receives from the environment. We can write the expected return of a policy as:

$J(\pi_\theta)=\mathrm{E}_{\tau\sim\pi_\theta}\ [R(\tau)]$,

where $\pi_\theta$ is a policy parametrised by $\theta$, $\mathrm{E}$ means *expectation*, $\tau$ is shorthand for "*episode*", $\tau\sim\pi_\theta$ is shorthand for "*episodes sampled using the policy* $\pi_\theta$", and $R(\tau)$ is the return of episode $\tau$.
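In practice we cannot compute this expectation exactly, but we can estimate it by averaging the returns of many sampled episodes. The sketch below assumes a hypothetical CartPole-like episode in which the agent gets $+1$ per step and, after each step, the episode ends with probability `fall_probability`; for that toy model the true expected return is `1 / fall_probability`. The names `sample_episode_return` and `estimate_expected_return` are illustrative only.

```python
import numpy as np


def sample_episode_return(rng, fall_probability=0.05):
    """Hypothetical episode: +1 reward per step; after each step the episode
    ends with probability fall_probability. Returns R(tau)."""
    episode_return = 0.0
    while True:
        episode_return += 1.0
        if rng.random() < fall_probability:
            return episode_return


def estimate_expected_return(num_episodes, seed=0):
    """Monte Carlo estimate of J: the mean return over sampled episodes."""
    rng = np.random.default_rng(seed)
    returns = [sample_episode_return(rng) for _ in range(num_episodes)]
    return float(np.mean(returns))


print(estimate_expected_return(20_000))  # should be close to 1 / 0.05 = 20
```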

Then, the goal in RL is to find the parameters $\theta$ that maximise the function $J(\pi_\theta)$. One way to find these parameters is to perform gradient ascent on $J(\pi_\theta)$ with respect to the parameters $\theta$: 

$\theta_{k+1}=\theta_k + \alpha \nabla J(\pi_\theta)|_{\theta_{k}}$,

where $\nabla J(\pi_\theta)|_{\theta_{k}}$ is the gradient of the expected return with respect to the policy parameters, evaluated at $\theta_k$, and $\alpha$ is the step size. This quantity, $\nabla J(\pi_\theta)$, is also called the **policy gradient** and is very important in RL. If we can compute the policy gradient, then we have a means of directly optimising our policy.
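To see the update rule $\theta_{k+1}=\theta_k + \alpha \nabla J|_{\theta_k}$ in action, here is a small sketch on a hypothetical objective $J(\theta) = -\lVert\theta - \theta^\ast\rVert^2$ whose gradient we know exactly. In RL we do not have this luxury: the gradient has to be estimated from sampled episodes, which is what the policy gradient formula provides. `THETA_STAR`, `toy_objective` and `toy_gradient` are illustrative names.

```python
import numpy as np

THETA_STAR = np.array([3.0, -1.0])  # hypothetical maximiser of the toy objective


def toy_objective(theta):
    """Toy J(theta) = -||theta - THETA_STAR||^2, maximised at THETA_STAR."""
    return -float(np.sum((theta - THETA_STAR) ** 2))


def toy_gradient(theta):
    """Exact gradient of toy_objective: -2 (theta - THETA_STAR)."""
    return -2.0 * (theta - THETA_STAR)


theta = np.zeros(2)
alpha = 0.1  # step size
for k in range(100):
    # Gradient ASCENT: theta_{k+1} = theta_k + alpha * grad J(theta_k)
    theta = theta + alpha * toy_gradient(theta)

print(theta)  # very close to [3, -1]
```

Each step shrinks the distance to $\theta^\ast$ by a factor of $1 - 2\alpha = 0.8$, so after 100 steps the remaining error is negligible.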

As it turns out, there is a way for us to compute the policy gradient, and the mathematical derivation can be found [here](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html). For this tutorial we will omit the derivation and just give you the result:


$\nabla_{\theta} J(\pi_{\theta})=\underset{\tau \sim \pi_{\theta}}{\mathrm{E}}[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid s_{t}) R(\tau)]$

Informally, the policy gradient is the expected value, over episodes sampled with $\pi_\theta$, of the sum over timesteps of the gradient of the log-probability of each chosen action, multiplied by the return of the episode in which that action was taken.
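The key ingredient is the "score" $\nabla_\theta \log \pi_\theta(a \mid s)$. As a sketch, assume a hypothetical single-state policy whose parameters are simply its logits, with $\pi = \mathrm{softmax}(\text{logits})$. For that case the score has the closed form $\text{one\_hot}(a) - \mathrm{softmax}(\text{logits})$, which the code below checks against finite differences. Multiplying the score by the return gives one sample's contribution to the policy gradient; when the return is positive, stepping along it raises the chosen action's logit (its component is $1 - \pi(a) > 0$) and lowers the others.

```python
import numpy as np


def softmax(logits):
    shifted = logits - np.max(logits)  # for numerical stability
    exp = np.exp(shifted)
    return exp / np.sum(exp)


def log_prob(logits, action):
    return np.log(softmax(logits)[action])


def score_function(logits, action):
    """Analytic gradient of log softmax(logits)[action] w.r.t. the logits."""
    one_hot = np.zeros_like(logits)
    one_hot[action] = 1.0
    return one_hot - softmax(logits)


def finite_difference_score(logits, action, h=1e-5):
    """Central-difference approximation of the same gradient, for checking."""
    grad = np.zeros_like(logits)
    for i in range(len(logits)):
        step = np.zeros_like(logits)
        step[i] = h
        grad[i] = (log_prob(logits + step, action) - log_prob(logits - step, action)) / (2 * h)
    return grad


logits = np.array([0.5, -1.0, 2.0])
print(score_function(logits, action=1))
print(finite_difference_score(logits, action=1))
```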


### REINFORCE
REINFORCE is a simple RL algorithm that uses the policy gradient to find the optimal policy by increasing the probability of choosing actions that tend to lead to high return episodes.

**Exercise 10:** Implement a function that takes the probability of an action and the return of the episode the action was taken in and returns the log of the probability multiplied by the return. Make sure you use jax so that we can jit the function.

**Useful functions:**
*   [Jax numpy log](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log.html)

In [None]:
def compute_weighted_log_prob(action_prob, episode_return):
    weighted_log_prob = None # you will need to overwrite this

    # YOUR CODE


    # END YOUR CODE

    return weighted_log_prob

# TEST: the result should be -22.314354
action_prob = 0.8
episode_return = 100
print("Weighted log prob:", compute_weighted_log_prob(action_prob, episode_return))

In [None]:
#@title Exercise 10 solution

def compute_weighted_log_prob(action_prob, episode_return):
    weighted_log_prob = None # you will need to overwrite this

    # YOUR CODE

    weighted_log_prob = jnp.log(action_prob) * episode_return

    # END YOUR CODE

    return weighted_log_prob

# TEST
action_prob = 0.8
episode_return = 100
print("Weighted log prob:", compute_weighted_log_prob(action_prob, episode_return))


### Rewards-to-go
The gradient of the log of the action's probability, weighted by the return of the episode, tends to push up the probability of actions that were in episodes with high return, regardless of where in the episode the action was taken. This does not make much sense, because an action near the end of an episode may be reinforced just because lots of reward was collected earlier on in the episode, before the action was taken. RL agents should only reinforce actions on the basis of their *consequences*. Rewards obtained before taking an action have no bearing on how good that action was: only the rewards from that timestep onwards do. The cumulative reward received from the timestep an action is taken onwards is called the **rewards-to-go** and can be computed as:

$\hat{R}_t=\sum_{t'=t}^{T} r_{t'}$

Compare the rewards-to-go with the episode return:

$R(\tau)=\sum_{t=0}^Tr_t$

Thus, the policy gradient with rewards-to-go is given by:

$\nabla_{\theta} J(\pi_{\theta})=\underset{\tau \sim \pi_{\theta}}{\mathrm{E}}[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid s_{t}) \hat{R}_t]$
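The only difference between the two estimators is the weight multiplying $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$ at each timestep. This NumPy sketch (function names are illustrative) compares the two weightings on a CartPole-like episode with $+1$ reward per step: with $R(\tau)$ every action gets the same weight, while with $\hat{R}_t$ later actions get smaller weights. Note that $\hat{R}_0 = R(\tau)$.

```python
import numpy as np


def episode_return_weights(rewards):
    """Weight R(tau) used at every timestep in the basic policy gradient."""
    total = float(np.sum(rewards))
    return np.full(len(rewards), total)


def reward_to_go_weights(rewards):
    """Weight R_hat_t = sum_{t'=t}^{T} r_{t'}, computed in one backward pass."""
    weights = np.zeros(len(rewards))
    running_sum = 0.0
    for t in reversed(range(len(rewards))):
        running_sum += rewards[t]
        weights[t] = running_sum
    return weights


rewards = [1.0, 1.0, 1.0, 1.0]
print(episode_return_weights(rewards))  # [4. 4. 4. 4.]
print(reward_to_go_weights(rewards))    # [4. 3. 2. 1.]
```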

**Exercise 11:** Implement a function that takes a list of all the rewards obtained in an episode and computes the rewards-to-go. Don't worry about using jax in this function. You can use regular Python operations like `for-loops`.

In [None]:
# Implement reward to go
def rewards_to_go(rewards):
    """
    This function should take a list of rewards as input and 
    compute the rewards-to-go for each timestep.
    
    Arguments:
        rewards[t] is the reward at time step t.

    Returns:
        rewards_to_go[t] should be the reward-to-go at timestep t.
    """

    rewards_to_go = []

    # YOUR CODE


    # END YOUR CODE

    return rewards_to_go

# TEST: The result should be [10, 9, 7, 4]
rewards = np.array([1,2,3,4])
print("Rewards-to-go:", rewards_to_go(rewards))

In [None]:
#@title Exercise 11 solution

def rewards_to_go(rewards):
    rewards_to_go = []
    for i in range(len(rewards)):
        r2g = 0
        for j in range(i, len(rewards)):
            r2g += rewards[j]
        rewards_to_go.append(r2g)
    return rewards_to_go

# TEST: The result should be [10, 9, 7, 4]
rewards = np.array([1,2,3,4])
print("Rewards-to-go:", rewards_to_go(rewards))

In [None]:
# Faster rewards to go calculation using numpy
def rewards_to_go(rewards):
    return np.flip(np.cumsum(np.flip(rewards)))

# TEST: The result should be [10, 9, 7, 4]
rewards = np.array([1,2,3,4])
print("Rewards-to-go:", rewards_to_go(rewards))

Next, we need a new agent memory that stores the rewards-to-go $\hat{R}_t$ along with the observation $o_t$ and action $a_t$ at every timestep.

In [None]:
# Now we need a new episode memory buffer

EpisodeRewardsToGoMemory = collections.namedtuple("EpisodeRewardsToGoMemory", ["obs", "action", "reward_to_go"])

class EpisodeRewardsToGoBuffer:

    def __init__(self, num_transitions_to_store=500, batch_size=500):
        self.batch_size = batch_size
        self.memory_buffer = collections.deque(maxlen=num_transitions_to_store)
        self.current_episode_transition_buffer = []

    def push(self, transition):
        self.current_episode_transition_buffer.append(transition)

        if transition.done:

            episode_rewards = []
            for t in self.current_episode_transition_buffer:
                episode_rewards.append(t.reward)

            r2g = rewards_to_go(episode_rewards)

            for i, t in enumerate(self.current_episode_transition_buffer):
                memory = EpisodeRewardsToGoMemory(t.obs, t.action, r2g[i])
                self.memory_buffer.append(memory)

            # Reset episode buffer
            self.current_episode_transition_buffer = []


    def is_ready(self):
        return len(self.memory_buffer) >= self.batch_size

    def sample(self):
        random_memory_sample = random.sample(self.memory_buffer, self.batch_size)

        obs_batch, action_batch, reward_to_go_batch = zip(*random_memory_sample)

        return EpisodeRewardsToGoMemory(
            np.stack(obs_batch).astype("float32"), 
            np.asarray(action_batch).astype("int32"), 
            np.asarray(reward_to_go_batch).astype("float32") # float, so non-integer rewards are not truncated
        )


# Instantiate Memory
REINFORCE_memory = EpisodeRewardsToGoBuffer(num_transitions_to_store=512, batch_size=256)

### Policy neural network
Next, we need to approximate the policy using a simple neural network. Our policy network takes the observation as input, passes it through two hidden layers, and outputs one scalar value (logit) for each of the possible actions. So, in CartPole the output layer has size `2`.

[Haiku](https://github.com/deepmind/dm-haiku) is a library for implementing neural networks in Jax. Below we have implemented a simple function to make the policy network for you. 
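If it helps to see what such a network computes, here is a plain NumPy sketch of the same layout: linear layers with ReLU activations between them and no activation on the output layer (to our understanding, the default behaviour of `hk.nets.MLP`). The functions `init_mlp_params` and `mlp_forward` are hypothetical, and the initialiser (a plain normal with standard deviation $1/\sqrt{\text{fan\_in}}$, zero biases) is an assumption for illustration; Haiku's own default initialiser may differ.

```python
import numpy as np


def init_mlp_params(rng, input_size, layer_sizes):
    """Hypothetical initialiser: Gaussian weights scaled by 1/sqrt(fan_in), zero biases."""
    params = []
    fan_in = input_size
    for size in layer_sizes:
        w = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, size))
        b = np.zeros(size)
        params.append((w, b))
        fan_in = size
    return params


def mlp_forward(params, obs):
    """Linear layers with ReLU on hidden layers only; the output layer gives logits."""
    x = obs
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:  # hidden layers only
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)
params = init_mlp_params(rng, input_size=4, layer_sizes=[10, 10, 2])  # 2 actions
logits = mlp_forward(params, np.array([1.0, 1.0, 1.0, 1.0]))
print(logits.shape)  # (2,)
```

The same parameters work for a batch of observations, e.g. an input of shape `(3, 4)` produces logits of shape `(3, 2)`.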


In [None]:
def make_policy_network(num_actions: int, layers=[10, 10]) -> hk.Transformed:
  """Factory for a simple MLP network for the policy."""

  def policy_network(obs):
    network = hk.Sequential(
        [
            hk.Flatten(),
            hk.nets.MLP(layers + [num_actions])
        ]
    )
    return network(obs)

  return hk.without_apply_rng(hk.transform(policy_network))

Haiku networks have two important methods you need to know about. The first is `<network>.init(<rng>, <input>)`, which returns a set of random initial parameters. The second is `<network>.apply(<params>, <input>)`, which passes an input through the network using the provided parameters.

In [None]:
# Example
POLICY_NETWORK = make_policy_network(num_actions=2, layers=[20,20])
random_key = next(rng) # get next random key
dummy_obs = np.array([1,1,1,1], "float32")

REINFORCE_params = POLICY_NETWORK.init(random_key, dummy_obs)
print("Initial params:", REINFORCE_params.keys())

output = POLICY_NETWORK.apply(REINFORCE_params, dummy_obs)
print("Policy network output:", output)


The outputs of our policy network are [logits](https://qr.ae/pv4YTe). To convert these into a probability distribution over actions, we pass the logits to the [softmax](https://en.wikipedia.org/wiki/Softmax_function) function (more on this later).
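As a reference point, here is a minimal NumPy sketch of the softmax, $\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$. Subtracting the largest logit before exponentiating does not change the result but prevents overflow for large logits. (Later in this practical the JAX version, `jax.nn.softmax`, is used instead.)

```python
import numpy as np


def softmax(logits):
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / np.sum(exp)


probs = softmax(np.array([0.0, 1.0]))
print(probs)  # roughly [0.269 0.731]
```

Only differences between logits matter, so `[1000.0, 1001.0]` gives the same probabilities as `[0.0, 1.0]`.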

### Action selection

**Exercise 12:** Complete the function below which takes a vector of logits and randomly samples an action. 

**Useful functions:**
*   [Jax random categorical](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.categorical.html)

In [None]:
def sample_action(random_key, logits):
    action = None # you will need to overwrite this
    
    # YOUR CODE HERE


    # END YOUR code

    return action

# TEST
for i in range(10):
  print("Action:", sample_action(next(rng), np.array([0, 1], "float32")))

In [None]:
#@title Exercise 12 solution

def sample_action(random_key, logits):
    action = None # you will need to overwrite this
    
    # YOUR CODE HERE
    action = jax.random.categorical(random_key, logits)
    # END YOUR code

    return action

# TEST: 
for i in range(10):
  print("Action:", sample_action(next(rng), np.array([0, 1], "float32")))


Notice in the tests that the actions are randomly sampled. Of course, the action with the higher probability will be chosen more often, but there is always a chance that the action with the lower probability will be chosen. This is actually desirable in RL, because it means the agent keeps trying new things in the environment. We call this **exploring**. Exploring is important because it helps the agent discover new, possibly better, strategies in the environment. When an agent chooses the best possible action (given its current knowledge), we say the agent is being **greedy**.
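To quantify "chosen more often", this NumPy sketch samples 10,000 actions from the logits `[0, 1]` used in the tests above and compares them with always picking the greedy action. With these logits the sampled policy picks action `1` about 73% of the time (its softmax probability), whereas the greedy policy picks it every time. NumPy's `Generator.choice` stands in for `jax.random.categorical` here.

```python
import numpy as np


def softmax(logits):
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / np.sum(exp)


logits = np.array([0.0, 1.0])
probs = softmax(logits)
rng = np.random.default_rng(0)

sampled = rng.choice(len(logits), size=10_000, p=probs)  # exploring: sample from pi
greedy = np.full(10_000, np.argmax(logits))              # greedy: always the argmax

print("fraction of action 1 (sampled):", np.mean(sampled == 1))  # close to 0.73
print("fraction of action 1 (greedy): ", np.mean(greedy == 1))   # exactly 1.0
```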

**Exercise 13:** Complete the function below which takes a vector of logits and returns the greedy action. 

**Useful functions:**
*   [Jax numpy argmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmax.html)


In [None]:
def greedy_action(logits):
    action = None # you will need to overwrite this

    # YOUR CODE

    # END YOUR CODE

    return action

# TEST
for i in range(10):
    print("Action:", greedy_action(np.array([0, 1], "float32")))

In [None]:
#@title Exercise 13 solution

def greedy_action(logits):
    action = None # you will need to overwrite this

    # YOUR CODE
    action = jnp.argmax(logits)
    # END YOUR CODE

    return action

# TEST
for i in range(10):
    print("Action:", greedy_action(np.array([0, 1], "float32")))


Notice that the greedy action selector always chooses the same action, namely the one with the highest probability (or, equivalently, the largest logit). Next, we will implement the REINFORCE `select_action` function. The function passes the observation through the policy neural network to get logits and then uses those to choose an action. If `evaluation` is `True`, the greedy action is chosen. Otherwise, the action is sampled from the action probability distribution given by the logits.


In [None]:
def REINFORCE_select_action(key, params, actor_state, obs, evaluation=False):
    obs = jnp.expand_dims(obs, axis=0) # add dummy batch dim
    logits = POLICY_NETWORK.apply(params, obs)[0] # remove batch dim

    sampled_action = sample_action(key, logits)

    best_action = greedy_action(logits)

    action = jax.lax.select(
        evaluation,
        best_action,
        sampled_action
    )
    
    return action, actor_state

# TEST
action, actor_state = REINFORCE_select_action(
    key=next(rng),
    params=REINFORCE_params, # we instantiated this earlier
    actor_state=None, # not used
    obs=np.array([1,1,1,1], "float32") # dummy obs
)

print("Action:", action)

Now that we have finished the REINFORCE action selection function, all we have left to do is make a REINFORCE learn function. The learn function should use the `compute_weighted_log_prob` function we made earlier to compute the policy gradient and apply the updates to our neural network.

### Network Optimiser

To apply updates to our neural network we will use a Jax library called [Optax](https://github.com/deepmind/optax). Optax has an implementation of the [Adam optimizer](https://www.geeksforgeeks.org/intuition-of-adam-optimizer/) which we can use.
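For intuition about what the optimizer does with a gradient, here is a NumPy sketch of one standard Adam step for a single parameter array. The hyperparameter values `b1=0.9`, `b2=0.999`, `eps=1e-8` are the commonly used defaults and, as far as we know, match `optax.adam`'s defaults; please check the Optax documentation before relying on that. Optax itself returns *updates* that are then added with `optax.apply_updates`; the sketch `adam_step` (a hypothetical helper) folds both into one function and subtracts the step, i.e. it performs gradient descent on a loss.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a single array; t counts steps starting at 1."""
    m = b1 * m + (1 - b1) * grad          # running mean of gradients
    v = b2 * v + (1 - b2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)             # bias correction for the zero initialisation
    v_hat = v / (1 - b2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v


param = np.array([1.0, -2.0])
grad = np.array([100.0, -0.01])  # very different gradient magnitudes
new_param, m, v = adam_step(param, grad, np.zeros(2), np.zeros(2), t=1)
print(new_param - param)  # roughly [-0.001, 0.001]
```

On the first step each parameter moves by roughly the learning rate, regardless of its gradient's magnitude; this per-parameter scaling is one reason Adam is a popular default.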

In [None]:
REINFORCE_OPTIMIZER = optax.adam(1e-3)
REINFORCE_optim_state = REINFORCE_OPTIMIZER.init(REINFORCE_params)

### Policy gradient loss

**Exercise 14:** Complete the `pg_loss` function below.

**Useful methods:**
*   [Jax softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html)
*   [Jax one-hot vector](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html)
*   [Jax dot product](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html)

In [None]:
def pg_loss(action, logits, reward_to_go):
    chosen_action_prob = 0.0 # you will need to overwrite this

    # YOUR CODE

    # HINT all_action_probs = ... (convert logits into probs)

    # HINT extract the prob of the desired action.
    # One way to achieve this is to use a one-hot vector and a dot product...?

    # END YOUR CODE
    weighted_log_prob = compute_weighted_log_prob(
                            chosen_action_prob, 
                            reward_to_go
                        )
    
    loss = - weighted_log_prob # negative because we want gradient `ascent`
    
    return loss

# TEST 
print("Policy gradient loss:", pg_loss(0, np.array([0,1], "float32"), 10))

In [None]:
#@title Exercise 14 solution

def pg_loss(action, logits, reward_to_go):

    # YOUR CODE
    
    all_action_probs = jax.nn.softmax(logits)
    action_mask = jax.nn.one_hot(action, logits.shape[0])
    chosen_action_prob = jnp.dot(all_action_probs, action_mask)

    # END YOUR CODE
    
    weighted_log_prob = compute_weighted_log_prob(
                            chosen_action_prob, 
                            reward_to_go
                        )
    
    loss = - weighted_log_prob 
    
    return loss

# TEST 
print("Policy gradient loss:", pg_loss(0, np.array([0,1], "float32"), 10))

Now, when we do a policy gradient update step we are going to want to do it using a batch of experience, rather than just a single experience like above. We can use Jax's [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) function to easily make our `pg_loss` function work on a batch of experience.
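Conceptually, the batched loss is just the per-example loss averaged over the batch. The NumPy sketch below computes the same quantity by vectorising directly instead of using `vmap`, and takes the logits as given rather than running the network. It also uses a log-softmax (`logits - logsumexp(logits)`) instead of `log(softmax(logits))`, which stays finite even when an action's probability underflows to zero; in JAX, `jax.nn.log_softmax` provides this. The function names are illustrative.

```python
import numpy as np


def log_softmax(logits):
    # x - logsumexp(x), computed stably along the last axis.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))


def batched_pg_loss_numpy(logits_batch, action_batch, reward_to_go_batch):
    """Mean over the batch of -log pi(a_i | s_i) * R_hat_i (logits given directly)."""
    log_probs = log_softmax(logits_batch)                           # shape (B, num_actions)
    chosen = log_probs[np.arange(len(action_batch)), action_batch]  # shape (B,)
    return float(np.mean(-chosen * reward_to_go_batch))


logits_batch = np.array([[0.0, 1.0], [0.0, 1.0], [2.0, 0.0]])
action_batch = np.array([1, 0, 0])
reward_to_go_batch = np.array([2.3, 4.3, 2.1])
print(batched_pg_loss_numpy(logits_batch, action_batch, reward_to_go_batch))
```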

In [None]:
def batched_pg_loss(params, obs_batch, action_batch, reward_to_go_batch):
    logits_batch = POLICY_NETWORK.apply(params, obs_batch) # network we made earlier
    pg_loss_batch = jax.vmap(pg_loss)(action_batch, logits_batch, reward_to_go_batch) # add batch
    mean_pg_loss = jnp.mean(pg_loss_batch)
    return mean_pg_loss

# TEST
obs_batch = np.array([[1,0,0,1],[1,0,0,1],[1,0,0,1]])
actions_batch = np.array([1,0,0])
rew2go_batch = np.array([2.3, 4.3, 2.1])

loss = batched_pg_loss(REINFORCE_params, obs_batch, actions_batch, rew2go_batch)

print("PG loss on batch:", loss)

Now we can make the REINFORCE learn function.

In [None]:
REINFORCELearnerState = collections.namedtuple("REINFORCELearnerState", ["optim_state"])

def REINFORCE_learn(key, params, learner_state, memory):
    
    # Get the gradient of the loss (i.e. the negated policy gradient) using `jax.grad()` on `batched_pg_loss`
    grad_loss = jax.grad(batched_pg_loss)(params, memory.obs, memory.action, memory.reward_to_go)

    # Get param updates using gradient and optimizer
    updates, new_optim_state = REINFORCE_OPTIMIZER.update(grad_loss, learner_state.optim_state)

    # Apply updates to params
    params = optax.apply_updates(params, updates)

    return params, REINFORCELearnerState(new_optim_state) # update learner state

# Lets jit the learn function and the select action function for some extra speed
REINFORCE_learn_jit = jax.jit(REINFORCE_learn)
REINFORCE_select_action_jit = jax.jit(REINFORCE_select_action)

### Training
Now we can train our REINFORCE agent by putting everything together using the environment loop. 

In [None]:
learner_state = REINFORCELearnerState(REINFORCE_optim_state)
actor_state = None # not used

episode_returns, evaluator_returns = run_environment_loop(rng, env, REINFORCE_params, REINFORCE_select_action_jit , actor_state, 
    REINFORCE_learn_jit, learner_state, REINFORCE_memory, num_episodes=1001)

# Plot the episode returns over time
plt.plot(episode_returns)
plt.show()


## Section 4: Q-Learning
Another common approach to finding an optimal policy in RL is Q-learning. 

### State-Action Value function
In Q-learning the agent learns a function that approximates the **value** of state-action pairs. By *value* we mean the return you expect to receive if you start in a particular state $s$, take a particular action $a$, and then act according to a particular policy $\pi$ forever after. The state-action value function of policy $\pi$ is given by

$Q_\pi(s,a)=\mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s,\ a_0=a\right]$.

We say that the value function $Q_\pi(s,a)$ is the **optimal** value function if the policy $\pi$ is an optimal policy. We denote the optimal value function as follows:

$Q_\ast(s,a)=\max \limits_\pi \  \mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s,\ a_0=a\right]$

There is an important relationship between the optimal action $a_\ast$ in a state $s$ and the optimal state-action value function $Q_\ast$. Namely, the optimal action $a_\ast$ in state $s$ is equal to the action that maximises the optimal state-action value function. This relationship naturally induces an optimal policy:

$\pi_\ast(s)=\arg \max \limits_a\ Q_\ast(s, a)$

This kind of policy is an example of a **greedy** policy. It is an optimal action selection strategy when the Q-function is optimal. However, since we only approximate the Q-function, the resulting greedy policy may not be optimal. In some environments such a policy still yields decent behavior, but most environments require an **exploration** strategy.
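Here is a NumPy sketch of both ideas. First, for a hypothetical Q-table with 3 states and 2 actions, the greedy policy $\pi(s)=\arg\max_a Q(s,a)$ is just an argmax per row. Second, as one widely used exploration strategy (named here for illustration; not necessarily the one this practical uses later), *epsilon-greedy* picks a uniformly random action with probability $\epsilon$ and the greedy action otherwise. The Q-values and the `epsilon_greedy` helper are made up for this example.

```python
import numpy as np


def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action (explore);
    otherwise pick the action with the largest Q-value (exploit)."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


# A hypothetical Q-table: rows are states, columns are actions.
q_table = np.array([[1.0, 0.5],
                    [0.2, 0.9],
                    [0.0, 3.0]])
greedy_policy = np.argmax(q_table, axis=1)  # pi(s) = argmax_a Q(s, a) for every state
print(greedy_policy)  # [0 1 1]

rng = np.random.default_rng(0)
actions = np.array([epsilon_greedy(q_table[0], epsilon=0.1, rng=rng) for _ in range(10_000)])
print(np.mean(actions == 0))  # close to 0.9 + 0.1 * 0.5 = 0.95
```

Setting `epsilon=0` recovers the purely greedy policy.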

### Greedy action selection

**Exercise 15:** Let's implement a function that, given a vector of Q-values, returns the action with the largest Q-value (i.e. the greedy action).

**Useful methods:**
*   [Jax argmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmax.html)

In [None]:
# Implement a function takes q-values as input and returns the greedy_action
def select_greedy_action(q_values):
    action = None # you will need to overwrite this

    # YOUR CODE

    # END YOUR CODE

    return action

In [None]:
# @title Check correctness of implementation (Run me) 

def check_select_greedy_action(select_greedy_action):
  q_values = jnp.array([1,1,3,4])
  action = select_greedy_action(q_values)

  assert action == 3, "Incorrect answer, your greedy action selector looks wrong"

  print("Looks good.")

check_select_greedy_action(select_greedy_action)


In [None]:
#@title Exercise 15 solution

def select_greedy_action(q_values):
    action = None # you will need to overwrite this
    
    # YOUR CODE
    action = jnp.argmax(q_values)
    # END YOUR CODE

    return action

### Q-Network
Unlike the policy gradient approaches from the previous section, Q-learning and other value-based reinforcement learning approaches don't need a direct parameterisation of the policy. This is because the approximated Q-function implicitly defines a policy, which can be recovered using:

$\hat{\pi}(s)=\arg \max \limits_a\ Q_{\theta}(s, a)$

As we did previously, we shall use Haiku to build a neural network that approximates this Q-function. The network takes an observation as input and outputs one Q-value for each of the available actions. So in the case of CartPole, the output of the network will have size $2$.
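To make the shapes concrete, here is a NumPy stand-in for the Haiku network, assuming random weights and the same hidden sizes `[10, 10]` used in `build_network` below. Like `hk.nets.MLP` with its default settings, it applies ReLU between layers but not after the final layer. The implicit policy $\hat{\pi}$ is simply an `argmax` over the last axis, which yields one action per observation in the batch.

```python
import numpy as np

# Hypothetical NumPy stand-in for the Haiku Q-network (random weights, no training).
rng = np.random.default_rng(0)
sizes = [4, 10, 10, 2]  # CartPole observation size -> hidden -> one Q-value per action
weights = [rng.normal(scale=0.5, size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

def q_network(obs):
    """Map a batch of observations [B, 4] to Q-values [B, 2]."""
    x = obs
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        if i < len(weights) - 1:  # ReLU on hidden layers only
            x = np.maximum(x, 0.0)
    return x

obs_batch = rng.normal(size=(5, 4))            # five fake CartPole observations
q_values = q_network(obs_batch)                # shape (5, 2)
greedy_actions = np.argmax(q_values, axis=-1)  # implicit policy: one action per row
print(q_values.shape, greedy_actions)
```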

In [None]:
def build_network(num_actions: int, layers=[10, 10]) -> hk.Transformed:
  """Factory for a simple MLP network for approximating Q-values."""

  def q_network(obs):
    network = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP(layers + [num_actions])])
    return network(obs)

  return hk.without_apply_rng(hk.transform(q_network))

Let's initialise our Q-network and get the initial weights.

In [None]:
# Initialise Q-network
Q_NETWORK = build_network(num_actions=2) # two actions

dummy_obs = jnp.zeros((1,4), jnp.float32) # a dummy observation like the one in CartPole

Q_NETWORK_WEIGHTS = Q_NETWORK.init(next(rng), dummy_obs) # Get initial weights

print("Q-Learning params:", Q_NETWORK_WEIGHTS.keys())

Before we implement the loss function required for training our Q-network, let's first discuss the intuition behind it.

### The Bellman Equations
The state-action value function can be written recursively as:

$Q_{\pi}(s, a) =\underset{s^{\prime} \sim P}{\mathrm{E}}\left[r(s, a)+\gamma \underset{a^{\prime} \sim \pi}{\mathrm{E}}\left[Q_{\pi}\left(s^{\prime}, a^{\prime}\right)\right]\right]$,

where $s' \sim P$ is shorthand for saying that the next state $s'$ is sampled from the environment’s transition function $P(s'\mid s,a)$, and $\gamma$ is the discount factor. Intuitively, this equation says that the value of taking action $a$ in state $s$ is equal to the reward you expect to get for taking it, plus the discounted "average" value (under $\pi$) of the actions in the state you transition to next. The Bellman equation for the optimal value function is:
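When the transition function is known, the Bellman equation can be turned into an algorithm: start from $Q=0$ and repeatedly replace the left-hand side with the right-hand side until nothing changes (iterative policy evaluation). The sketch below assumes the same hypothetical 4-cell corridor as a toy model (states 0 to 2 non-terminal, state 3 terminal, reward 1 for entering state 3), a uniform random policy and $\gamma = 0.9$. Because the toy environment is deterministic, the expectation over $s' \sim P$ reduces to a single term.

```python
import numpy as np

# Hypothetical toy corridor: action 0 = left (wall at state 0), action 1 = right.
N_STATES, N_ACTIONS, TERMINAL, GAMMA = 4, 2, 3, 0.9

def step(s, a):
    s_next = max(s - 1, 0) if a == 0 else s + 1
    return s_next, float(s_next == TERMINAL)  # (next state, reward)

def evaluate_policy(pi, num_iters=500):
    """Iterate Q(s,a) <- r(s,a) + gamma * sum_a' pi(a'|s') Q(s',a')."""
    q = np.zeros((N_STATES, N_ACTIONS))  # the terminal row stays 0
    for _ in range(num_iters):
        new_q = np.zeros_like(q)
        for s in range(TERMINAL):
            for a in range(N_ACTIONS):
                s_next, r = step(s, a)
                new_q[s, a] = r + GAMMA * pi[s_next] @ q[s_next]  # "average" over a'
        q = new_q
    return q

uniform = np.full((N_STATES, N_ACTIONS), 0.5)  # pi(a|s) = 0.5 for both actions
q_pi = evaluate_policy(uniform)
print(np.round(q_pi, 3))
```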

$Q_{*}(s, a) =\underset{s^{\prime} \sim P}{\mathrm{E}}\left[r(s, a)+\gamma\ \underset{a^{\prime}}{\max}\ Q_{*}(s^{\prime}, a^{\prime})\right]$

Notice that in this version of the Bellman equation, the "average" value across actions in the next state $s^{\prime}$ is replaced by the value of the best action in $s^{\prime}$, i.e. the one with the largest action-value.
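Applying the optimality equation repeatedly gives Q-value iteration, which converges to $Q_\ast$ when the transition model is known. The sketch below assumes the same hypothetical toy corridor and $\gamma=0.9$; the only change from evaluating a fixed policy is the `max` over next actions. The greedy policy read off $Q_\ast$ moves right in every state, and the values can be checked by hand (for example $Q_\ast(0,\text{right}) = 0.9^2$).

```python
import numpy as np

# Hypothetical toy corridor: states 0-2 non-terminal, 3 terminal, reward 1 for reaching 3.
N_STATES, N_ACTIONS, TERMINAL, GAMMA = 4, 2, 3, 0.9

def step(s, a):
    s_next = max(s - 1, 0) if a == 0 else s + 1
    return s_next, float(s_next == TERMINAL)  # (next state, reward)

def q_value_iteration(num_iters=100):
    q = np.zeros((N_STATES, N_ACTIONS))
    for _ in range(num_iters):
        new_q = np.zeros_like(q)
        for s in range(TERMINAL):
            for a in range(N_ACTIONS):
                s_next, r = step(s, a)
                new_q[s, a] = r + GAMMA * q[s_next].max()  # max instead of average over a'
        q = new_q
    return q

q_star = q_value_iteration()
print(np.round(q_star, 3))
print("greedy policy:", q_star[:TERMINAL].argmax(axis=1))  # expect 1 ("right") everywhere
```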


For a more in-depth discussion of the Bellman Equations, see the [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/spinningup/rl_intro.html) website.

### The Bellman Backup
To learn to approximate the optimal Q-value function, we can use the right-hand side of the Bellman equation as an update rule. In other words, suppose we have a Q-function $Q_\theta$ with parameters $\theta$; we can then iteratively update the parameters so that

$Q_\theta(s,a)\leftarrow r(s, a) + \gamma \underset{a'}{\max}\ Q_\theta(s', a')$.

Intuitively, this says that the approximation of the Q-value of action $a$ in state $s$ should be updated so that it moves closer to the reward received from the environment, $r(s, a)$, plus the discounted value of the best possible action in the next state $s'$. We can perform this optimisation by using gradient descent to minimise the difference between the left- and right-hand sides with respect to the parameters $\theta$. We measure the difference between the two values using the [squared error](https://en.wikipedia.org/wiki/Mean_squared_error#Loss_function).
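To see what one gradient step does, the sketch below uses a hypothetical linear Q-function $Q_\theta(s,a) = \theta_a^\top \phi(s)$ instead of a neural network, with made-up features and a made-up learning rate. The Bellman target is computed once and treated as a constant, so the gradient of $(Q_\theta(s,a) - y)^2$ is $2(Q_\theta(s,a) - y)\,\phi(s)$ for the chosen action's weights only. With a unit-norm feature vector and learning rate $0.05$, the error shrinks by a factor of exactly $(1 - 2 \cdot 0.05)^2 = 0.81$.

```python
import numpy as np

# Hypothetical linear Q-function, used only to illustrate one gradient-descent step.
rng = np.random.default_rng(0)
theta = rng.normal(size=(2, 4))          # one weight vector per action
phi_s = rng.normal(size=4)
phi_s /= np.linalg.norm(phi_s)           # unit-norm features for the current state
phi_next = rng.normal(size=4)            # features for the next state
action, reward, gamma, lr = 1, 1.0, 0.99, 0.05

def q(theta, phi):
    return theta @ phi                   # vector of Q-values, one per action

target = reward + gamma * q(theta, phi_next).max()   # right-hand side, held fixed
error_before = (q(theta, phi_s)[action] - target) ** 2

# Gradient of (theta[action] . phi_s - target)^2 w.r.t. theta; other rows get zero.
grad = np.zeros_like(theta)
grad[action] = 2.0 * (q(theta, phi_s)[action] - target) * phi_s
theta_new = theta - lr * grad

error_after = (q(theta_new, phi_s)[action] - target) ** 2
print(error_before, error_after)
```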

**Exercise 16:** Implement the squared-error function.

In [None]:
def compute_squared_error(pred, target):
  # YOUR CODE
  squared_error = ...
  # END YOUR CODE

  return squared_error

In [None]:
#@title Exercise 16 solution
def compute_squared_error(pred, target):
    squared_error = None

    # YOUR CODE
    squared_error = jnp.square(pred - target)
    # END YOUR CODE
    return squared_error

**Exercise 17:** Implement a function that computes the **Bellman target** (right-hand side of the Bellman equation). If the episode is at the last timestep (i.e. done==1.0), then the Bellman target should be equal to the reward, with no extra value at the end.  

Assume a discount factor of 1 ($\gamma = 1$).
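As a worked example of the target, including the terminal case, here is a batched NumPy version with an explicit discount factor. The `gamma` argument is an addition for illustration; the exercise itself fixes $\gamma = 1$. The `(1.0 - dones)` factor removes the bootstrapped value for terminal transitions, so the third target is just the reward.

```python
import numpy as np

def bellman_targets(rewards, dones, next_q_values, gamma=0.99):
    """Batched Bellman targets with a discount factor.
    rewards, dones: shape [B]; next_q_values: shape [B, num_actions]."""
    return rewards + gamma * (1.0 - dones) * next_q_values.max(axis=-1)

rewards = np.array([1.0, 1.0, 1.0])
dones = np.array([0.0, 0.0, 1.0])        # the last transition is terminal
next_q = np.array([[2.0, 5.0], [3.0, 1.0], [9.0, 9.0]])

print(bellman_targets(rewards, dones, next_q, gamma=1.0))  # targets 6, 4, 1
print(bellman_targets(rewards, dones, next_q, gamma=0.5))  # targets 3.5, 2.5, 1
```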

In [None]:
# Bellman target
def compute_bellman_target(reward, done, next_q_values):
  """A function to compute the bellman target.
  
  Args:
      reward: a scalar reward.
      done: a scalar of value either 1.0 or 0.0, indicating if the transition is a terminal one.
      next_q_values: a vector of q_values for the next state. One for each action.
  Returns:
      A scalar equal to the bellman target.
  
  """
  # YOUR CODE
  bellman_target = ...
  # END YOUR CODE

  return bellman_target

In [None]:
#@title Exercise 17 solution

# Bellman target
def compute_bellman_target(reward, done, next_q_values):
    """A function to compute the bellman target.
    
    Args:
        reward: a scalar reward.
        done: a scalar of value either 1.0 or 0.0, indicating if the transition is a terminal one.
        next_q_values: a vector of q_values for the next state. One for each action.
    Returns:
        A scalar equal to the bellman target.
    
    """
    # YOUR CODE
    bellman_target = reward + (1.0 - done) * jnp.max(next_q_values)
    # END YOUR CODE

    return bellman_target


We can now combine these two functions to compute the loss for Q-learning. The Q-learning loss is equal to the squared difference between the predicted Q-value of an action and its corresponding Bellman target.

**Exercise 18:** Implement the Q-learning loss.

In [None]:
def q_learning_loss(q_values, action, reward, done, next_q_values):
    """Implementation of the Q-learning loss.
    
    Args:
        q_values: a vector of Q-values, one for each action.
        action: an integer giving the action that was chosen. q_values[action] is the value of the chosen action.
        reward: a scalar reward received after taking the action.
        done: a scalar that indicates if this is a terminal transition.
        next_q_values: a vector of Q-values in the next state.
    Returns:
        The squared difference between the Q-value of the chosen action and the Bellman target.
    """
    # YOUR CODE
    chosen_action_q_value = ...
    bellman_target = ...
    squared_error = ...
    # END YOUR CODE
    
    return squared_error

In [None]:
#@title Exercise 18 solution

def q_learning_loss(q_values, action, reward, done, next_q_values):
    """Implementation of the Q-learning loss.
    
    Args:
        q_values: a vector of Q-values, one for each action.
        action: an integer giving the action that was chosen. q_values[action] is the value of the chosen action.
        reward: a scalar reward received after taking the action.
        done: a scalar that indicates if this is a terminal transition.
        next_q_values: a vector of Q-values in the next state.
    Returns:
        The squared difference between the Q-value of the chosen action and the Bellman target.
    """
    # YOUR CODE
    chosen_action_q_value = q_values[action]
    bellman_target = compute_bellman_target(reward, done, next_q_values)
    squared_error = compute_squared_error(chosen_action_q_value, bellman_target)
    # END YOUR CODE
    
    return squared_error

### Target Q-network
Notice that when we compute the Bellman target, we use our Q-network $Q_\theta$ to compute the value of the next state $s_{t+1}$. In other words, we use our latest approximation of the Q-function to compute the target for our next approximation. This is called **bootstrapping**. Unfortunately, if we naively bootstrap like this, training the neural network can become very unstable. To mitigate this, we can instead use a separate set of parameters $\hat{\theta}$ to compute the values at state $s_{t+1}$. We keep the parameters $\hat{\theta}$ fixed and only periodically set them equal to the latest online parameters $\theta$, say every 100 training steps. This keeps the Bellman targets fixed for a number of training steps, which helps reduce the instability caused by bootstrapping.
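The effect of the periodic copy can be simulated without a network. In this hypothetical sketch, a single float stands in for the parameters and "training" just adds 1 to it at every learner step. The target copy is refreshed only when `step % sync_period == 0`, after that step's update, which mirrors the order used in `q_learner_step` further down. Between refreshes, the target, and therefore every Bellman target computed from it, stays constant.

```python
# Hypothetical scalar "parameters" illustrating a hard target-network update.
def simulate_target_sync(num_steps=350, sync_period=100):
    online, target = 0.0, 0.0
    target_history = []
    for step in range(num_steps):
        online += 1.0                 # stand-in for an optimizer update
        if step % sync_period == 0:
            target = online           # copy online -> target
        target_history.append(target)
    return target_history

history = simulate_target_sync()
print(sorted(set(history)))  # [1.0, 101.0, 201.0, 301.0]
```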


We will need to keep track of the latest (online) parameters as well as the target network's parameters. Let's make a `namedtuple` to store these two values. We will also need to keep track of the number of learner steps we have taken, so that we know when to update the target network.

In [None]:
QLearnerState = collections.namedtuple("LearnerState", ["count", "optim_state"])
QLearnerParams = collections.namedtuple("Params", ["online", "target"])

We will once again be using Optax to optimize our neural network in Jax. Here we instantiate the optimizer and add the initial Q-network weights to a `QLearnerParams` object.

In [None]:
# Initialise Q-network optimizer
Q_LEARN_OPTIMIZER = optax.adam(1e-3) # learning rate = 0.001

Q_LEARN_OPTIM_STATE = Q_LEARN_OPTIMIZER.init(Q_NETWORK_WEIGHTS) # initial optim state

# Create Learn State
Q_LEARNING_LEARN_STATE = QLearnerState(0, Q_LEARN_OPTIM_STATE) # count set to zero initially

# Add initial Q-network weights to QLearnerParams object
Q_LEARNING_PARAMS = QLearnerParams(online=Q_NETWORK_WEIGHTS, target=Q_NETWORK_WEIGHTS) # target equal to online

Now we can implement a simple function that copies the online parameters into the target parameters every 100 training steps.

In [None]:
def update_target_params(learn_state, online_weights, target_weights):

  target = jax.lax.cond(
      jax.numpy.mod(learn_state.count, 100) == 0,
      lambda x, y: x,
      lambda x, y: y,
      online_weights, 
      target_weights
  )

  params = QLearnerParams(online_weights, target)

  return params

### Q-learning loss
We now have almost everything we need to implement the `q_learner_step` function, which takes a batch of transitions and does a step of Q-learning to update the network parameters. But first we use `jax.vmap` to modify the `q_learning_loss` function so that it accepts batches of transitions. In addition, we compute the Q-values by passing the observations through `Q_NETWORK` with the online parameters, and the next-state Q-values by passing the next observations through `Q_NETWORK` with the target parameters.
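`jax.vmap(q_learning_loss)` applies the per-transition loss to every row of the batch. As a sanity check of what that computes, the NumPy sketch below (random made-up data, $\gamma = 1$) compares an explicit Python loop over transitions with a vectorised version that picks out $Q(s_i, a_i)$ using `np.take_along_axis`. Both give the same mean squared error.

```python
import numpy as np

def per_example_loss(q_values, action, reward, done, next_q_values, gamma=1.0):
    target = reward + gamma * (1.0 - done) * np.max(next_q_values)
    return (q_values[action] - target) ** 2

def batched_loss(q_values, actions, rewards, dones, next_q_values, gamma=1.0):
    """q_values, next_q_values: [B, num_actions]; actions, rewards, dones: [B]."""
    chosen = np.take_along_axis(q_values, actions[:, None], axis=1)[:, 0]  # Q(s_i, a_i)
    targets = rewards + gamma * (1.0 - dones) * next_q_values.max(axis=1)
    return np.mean((chosen - targets) ** 2)

rng = np.random.default_rng(0)
B = 8
q_values, next_q_values = rng.normal(size=(B, 2)), rng.normal(size=(B, 2))
actions = rng.integers(0, 2, size=B)
rewards = rng.normal(size=B)
dones = (rng.random(B) < 0.2).astype(np.float64)

loop_loss = np.mean([per_example_loss(q_values[i], actions[i], rewards[i],
                                      dones[i], next_q_values[i]) for i in range(B)])
print(loop_loss, batched_loss(q_values, actions, rewards, dones, next_q_values))
```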

In [None]:
def batched_q_learning_loss(online_params, target_params, obs, actions, rewards, next_obs, dones):
    q_values = Q_NETWORK.apply(online_params, obs)
    next_q_values = Q_NETWORK.apply(target_params, next_obs)
    squared_error = jax.vmap(q_learning_loss)(q_values, actions, rewards, dones, next_q_values) # vmap
    mean_squared_error = jnp.mean(squared_error) # mean squared error
    return mean_squared_error

Now we can create the `q_learner_step` function, which computes the gradient of `batched_q_learning_loss`, uses the Optax optimizer to update the online network weights and, every 100 steps, copies them into the target parameters.

In [None]:
def q_learner_step(rng, params, learner_state, memory):
  # Compute gradients
  grad_loss = jax.grad(batched_q_learning_loss)(params.online, params.target, memory.obs, 
                                          memory.action, memory.reward, 
                                          memory.next_obs, memory.done,
                                          )

  # Get updates
  updates, opt_state = Q_LEARN_OPTIMIZER.update(grad_loss, learner_state.optim_state)

  # Apply them
  new_weights = optax.apply_updates(params.online, updates)

  # Maybe update target network
  params = update_target_params(learner_state, new_weights, params.target)

  # Increment learner step counter
  learner_state = QLearnerState(learner_state.count + 1, opt_state)

  return params, learner_state

### Replay Buffer
For Q-learning we will need an agent memory that stores entire transitions: `obs`, `action`, `reward`, `next_obs`, `done`. When we retrieve transitions from the memory, they should be chosen randomly. In RL we often call such a module a **replay buffer**.

In [None]:
class TransitionMemory(object):
  """A simple Python replay buffer."""

  def __init__(self, max_size=5000, batch_size=256):
    self.batch_size = batch_size
    self.buffer = collections.deque(maxlen=max_size)

  def push(self, transition):
    # Store the transition as a plain tuple of its fields.
    self.buffer.append(
        (transition.obs, transition.action, transition.reward,
         transition.next_obs, transition.done)
    )

  def is_ready(self):
    return self.batch_size <= len(self.buffer)

  def sample(self):
    random_replay_sample = random.sample(self.buffer, self.batch_size)
    obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*random_replay_sample)

    return Transition(
        np.stack(obs_batch).astype("float32"), 
        np.asarray(action_batch).astype("int32"), 
        np.asarray(reward_batch).astype("float32"), 
        np.stack(next_obs_batch).astype("float32"), 
        np.asarray(done_batch).astype("float32")
    )

Q_LEARNING_MEMORY = TransitionMemory(max_size=5000, batch_size=256)
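Two properties of this buffer are worth seeing in isolation. Because it is a `collections.deque` with `maxlen`, the oldest transitions are evicted first once it is full. And `random.sample` draws a batch uniformly at random *without* replacement. The tiny sketch below uses placeholder tuples tagged with a timestep `t` instead of real `Transition` objects.

```python
import collections
import random

buffer = collections.deque(maxlen=3)
for t in range(5):
    buffer.append(("obs", t))   # hypothetical placeholder transitions

print(list(buffer))             # only t = 2, 3, 4 remain

random.seed(0)
batch = random.sample(buffer, 2)   # two distinct transitions, chosen uniformly
print(batch)
```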

### Random exploration
We almost have everything we need for a functioning Q-learning agent. One problem, however, is that if the agent always chooses the action with the highest Q-value, its policy is completely deterministic: it will always follow the same strategy. This is a problem because at the start of training the Q-network is very inaccurate (i.e. a bad approximation of the true Q-function), so the agent will consistently choose suboptimal actions. Moreover, the agent will never deviate from its suboptimal strategy and will never discover new, potentially more rewarding actions. As a result, the Q-network remains inaccurate. Ideally, the agent should try out many different strategies so that it can observe the outcomes (rewards) of its actions in different states and so improve its approximation of the Q-function.

One easy way to ensure that the agent tries out many different actions is to let it periodically choose some random actions, instead of the greedy (best) action all the time.
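The failure mode described above shows up even in the simplest possible setting. The sketch below assumes a hypothetical 2-armed bandit with deterministic rewards (arm 0 pays 1.0, arm 1 pays 2.0) and equal initial value estimates. The purely greedy agent locks onto arm 0 and never learns that arm 1 is better. An agent that takes a random action 10% of the time discovers arm 1 and then mostly exploits it. Value estimates here are simple sample averages, not a Q-network.

```python
import random

REWARDS = [1.0, 2.0]  # hypothetical deterministic payoff of each arm

def run(epsilon, steps=1000, seed=0):
    rng = random.Random(seed)
    q = [0.0, 0.0]      # equal initial estimates, so the greedy tie-break picks arm 0
    counts = [0, 0]
    for _ in range(steps):
        if rng.random() < epsilon:
            a = rng.randrange(2)                    # explore
        else:
            a = max(range(2), key=lambda i: q[i])   # exploit (ties -> lowest index)
        counts[a] += 1
        q[a] += (REWARDS[a] - q[a]) / counts[a]     # incremental sample average
    return q, counts

print(run(epsilon=0.0))   # ([1.0, 0.0], [1000, 0]): arm 1 is never tried
print(run(epsilon=0.1))   # arm 1 is discovered and then chosen most of the time
```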

**Exercise 19:** Implement a function that, given the number of possible (discrete) actions, returns a random action.

**Useful methods:**

*   [Jax random int](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.randint.html)

In [None]:
def select_random_action(key, num_actions):
    
    # YOUR CODE
    action = ...
    # END YOUR CODE

    return action

# TEST
for i in range(10):
  print(f"Random action number {i}: {select_random_action(next(rng), 2)}")

In [None]:
#@title Exercise 19 solution

def select_random_action(key, num_actions):
    # YOUR CODE
    action = jax.random.randint(
        key, 
        shape=(), 
        minval=0, 
        maxval=num_actions
    )
    # END YOUR CODE

    return action

# TEST
for i in range(10):
  print(f"Random action number {i}: {select_random_action(next(rng), 2)}")

### $\varepsilon$-greedy action selection
At the start of training, when the accuracy of the Q-network is low, it is worthwhile for the agent to mostly take random actions; this is called **exploration**. However, as the accuracy of the Q-network improves, the agent should take fewer random actions and instead choose the greedy actions with respect to the Q-values. Choosing actions according to the agent's current implicit or explicit policy is referred to as **exploitation.** In RL we often call the probability of taking a random action **epsilon** $\varepsilon$. Epsilon is a value in the interval $[0,1]$; for example, $\varepsilon=0.4$ means that the agent chooses a random action 40% of the time and the greedy action 60% of the time. It is common in RL to linearly decrease the value of epsilon over time so that the agent becomes increasingly greedy as the accuracy of its Q-network improves through learning.
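A plain-Python sketch of both ideas follows: a linear schedule that anneals epsilon from 1.0 down to a floor, and an epsilon-greedy choice. The constants mirror the ones defined for the next exercise, and the function names are hypothetical. One subtlety: a "random" action can coincide with the greedy one. With $\varepsilon=0.4$ and 2 actions, the greedy action is therefore taken with probability $0.6 + 0.4/2 = 0.8$, which the simulation confirms empirically.

```python
import random

EPSILON_DECAY_TIMESTEPS = 1_000
EPSILON_MIN = 0.1

def linear_epsilon(t):
    """Anneal epsilon linearly from 1.0, but never go below EPSILON_MIN."""
    return max(1.0 - t / EPSILON_DECAY_TIMESTEPS, EPSILON_MIN)

def epsilon_greedy(q_values, epsilon, rng):
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))                      # random action
    return max(range(len(q_values)), key=lambda i: q_values[i])  # greedy action

print([linear_epsilon(t) for t in (0, 250, 500, 2_000)])  # [1.0, 0.75, 0.5, 0.1]

rng = random.Random(0)
picks = [epsilon_greedy([0.0, 1.0], 0.4, rng) for _ in range(100_000)]
print(sum(p == 1 for p in picks) / len(picks))  # close to 0.8
```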


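**Exercise 20:** Implement a function that takes the number of timesteps as input and returns the current epsilon value. Epsilon should decrease linearly from $1.0$ at timestep $0$ towards $0$ at `EPSILON_DECAY_TIMESTEPS`, but never drop below `EPSILON_MIN`.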

In [None]:
EPSILON_DECAY_TIMESTEPS = 1_000
EPSILON_MIN = 0.1 # 10% exploration

In [None]:
def get_epsilon(num_timesteps):
  # YOUR CODE
  epsilon = ...
  # END YOUR CODE

  return epsilon

# TEST

print("Epsilon after 10 timesteps:", get_epsilon(10))
print("Epsilon after 5 010 timesteps:", get_epsilon(5_010))


In [None]:
#@title Exercise 20 solution

def get_epsilon(num_timesteps):

  # YOUR CODE
  # Linearly decay epsilon from 1.0 over EPSILON_DECAY_TIMESTEPS...
  epsilon = 1.0 - num_timesteps / EPSILON_DECAY_TIMESTEPS

  # ...but never let it drop below EPSILON_MIN.
  epsilon = jnp.maximum(epsilon, EPSILON_MIN)
  # END YOUR CODE

  return epsilon

# TEST

print("Epsilon after 10 timesteps:", get_epsilon(10))
print("Epsilon after 5 010 timesteps:", get_epsilon(5_010))


**Exercise 21:** Now let's put these functions together to do epsilon-greedy action selection.

In [None]:
def select_epsilon_greedy_action(key, q_values, num_timesteps):    
    # YOUR CODE HERE
    action = ...
    # END YOUR CODE

    return action

# TEST
dummy_q_values = jnp.array([0,1], jnp.float32)
num_timesteps = 5010 # very greedy
print("Greedy actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")
print()

num_timesteps = 0 # completely random
print("Random actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")

In [None]:
#@title Exercise 21 solution

# Now make a function that takes an epsilon-greedy action

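def select_epsilon_greedy_action(key, q_values, num_timesteps):

    epsilon = get_epsilon(num_timesteps)

    # Use separate keys for the explore/exploit coin flip and the random action,
    # so that the two random draws are independent.
    explore_key, action_key = jax.random.split(key)

    should_explore = jax.random.uniform(explore_key, ()) < epsilon

    num_actions = len(q_values)

    action = jax.lax.select(
        should_explore,
        select_random_action(action_key, num_actions),
        select_greedy_action(q_values)
    )

    return action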

# TEST
dummy_q_values = jnp.array([0,1], jnp.float32)
num_timesteps = 5010 # very greedy
print("Greedy actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")
print()

num_timesteps = 0 # completely random
print("Random actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")

### Q-learning select action

We now have everything we need to make the `q_learning_select_action` function. We will use the `actor_state` to store a counter which keeps track of the current number of timesteps.

In [None]:
# Actor state stores the current number of timesteps
QActorState = collections.namedtuple("ActorState", ["count"])

def q_learning_select_action(key, params, actor_state, obs, evaluation=False):
    obs = jnp.expand_dims(obs, axis=0) # add dummy batch dim
    q_values = Q_NETWORK.apply(params.online, obs)[0] # remove batch dim

    action = select_epsilon_greedy_action(key, q_values, actor_state.count)
    greedy_action = select_greedy_action(q_values)

    action = jax.lax.select(
        evaluation,
        greedy_action,
        action
    )

    next_actor_state = QActorState(actor_state.count + 1) # increment timestep counter

    return action, next_actor_state

Q_LEARNING_ACTOR_STATE = QActorState(0) # counter set to zero

### Training
We can now put everything together using the agent-environment loop. But first, let's jit the select action function and the learner step function for some extra speed.

In [None]:
# Jit functions
q_learning_select_action_jit = jax.jit(q_learning_select_action)
q_learner_step_jit = jax.jit(q_learner_step)

# Initialise memory
memory = TransitionMemory(10_000, 512) # store 10000 transitions

# Run environment loop
episode_returns, evaluator_returns = run_environment_loop(
                                        rng, 
                                        env, 
                                        Q_LEARNING_PARAMS, 
                                        q_learning_select_action_jit, 
                                        Q_LEARNING_ACTOR_STATE,
                                        q_learner_step_jit, 
                                        Q_LEARNING_LEARN_STATE, 
                                        memory,
                                        num_episodes=1_001,
                                        learn_steps_per_episode=16
                                    )

plt.plot(episode_returns)
plt.show()

At this stage, the approximated Q-function has hopefully converged to a decent (implicit) policy for balancing the pole in the CartPole problem.

This section attempts to summarise [Playing Atari with Deep Reinforcement Learning](https://arxiv.org/abs/1312.5602), the research paper that introduced Deep Q-Learning. To understand the concepts covered in this section better, we recommend you give it a read.
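To see the core update of this section without a neural network, replay buffer or target network, here is a tabular Q-learning sketch. It assumes the same hypothetical 4-cell corridor used as a toy example earlier (states 0 to 2 non-terminal, state 3 terminal, reward 1 for entering it) and arbitrary hyperparameters (`gamma=0.9`, step size `alpha=0.5`, fixed `epsilon=0.2`). Each step selects an epsilon-greedy action, computes the Bellman target with the `done` mask, and moves $Q(s,a)$ part of the way towards it. The learned greedy policy moves right in every state, and $Q(0,\text{right})$ approaches $0.9^2 = 0.81$.

```python
import random

# Hypothetical toy corridor: action 0 = left (wall at state 0), action 1 = right.
N_STATES, N_ACTIONS, TERMINAL = 4, 2, 3

def step(s, a):
    s_next = max(s - 1, 0) if a == 0 else s + 1
    done = s_next == TERMINAL
    return s_next, (1.0 if done else 0.0), float(done)

def train(num_episodes=500, gamma=0.9, alpha=0.5, epsilon=0.2, max_steps=100, seed=0):
    rng = random.Random(seed)
    q = [[0.0] * N_ACTIONS for _ in range(N_STATES)]
    for _ in range(num_episodes):
        s = 0
        for _ in range(max_steps):
            # Epsilon-greedy action selection.
            if rng.random() < epsilon:
                a = rng.randrange(N_ACTIONS)
            else:
                a = max(range(N_ACTIONS), key=lambda i: q[s][i])
            s_next, r, done = step(s, a)
            # Bellman target: no bootstrapped value after a terminal transition.
            target = r + gamma * (1.0 - done) * max(q[s_next])
            q[s][a] += alpha * (target - q[s][a])  # move Q(s, a) towards the target
            if done:
                break
            s = s_next
    return q

q = train()
print([max(range(N_ACTIONS), key=lambda i: q[s][i]) for s in range(TERMINAL)])  # expected [1, 1, 1]
print(round(q[0][1], 3), round(q[2][1], 3))  # approximately 0.81 and 1.0
```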

## Conclusion
**Summary:**

In this practical we learnt the basics of reinforcement learning (RL).

In the first section we learnt some basic concepts such as environment observations, action selection strategies, rewards and episodes. We learnt that the goal in RL is to learn a policy that maximises the cumulative reward the agent receives from the environment (the return).

In the second section we searched for an optimal policy in CartPole using an algorithm called RandomSearch. Basically, we tried out different policies until we happened to find one that worked well. This method did not yield consistent results and success required immense luck.

In the third section we learnt about policy gradients and how we can use gradient ascent to adjust the parameters of our agent's policy in the direction that maximises the expected cumulative reward (return).

Finally, in the fourth section we learnt about the state-action value function and how it is related to an optimal policy. We implemented an algorithm called Q-learning to learn the optimal state-action value function in CartPole. We learnt about the importance of using a target network and epsilon-greedy exploration.

**Next Steps:** 

Now that you have successfully solved CartPole with two different RL algorithms, REINFORCE and Deep Q-Learning, we now encourage you to use what you have learnt to try and solve some more challenging environments. OpenAI Gym is a great place to find RL environments. [LunarLander](https://www.gymlibrary.ml/environments/box2d/lunar_lander/) is a great next step.

In addition, there are many RL algorithms out there that make significant improvements to REINFORCE and Deep Q-Learning. See these resources:
* [REINFORCE with baseline](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#baselines-in-policy-gradients)
* [Double Deep Q-Network](https://arxiv.org/pdf/1509.06461.pdf)
* [Proximal Policy Optimisation (PPO)](https://arxiv.org/pdf/1707.06347.pdf)

**Appendix:** 

N/a

**References:** 

* [OpenAI Spinning Up](https://spinningup.openai.com/en/latest/)
* [Playing Atari with Deep Reinforcement Learning (Deep Q-Network)](https://arxiv.org/abs/1312.5602)

For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2022).

## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
#@title Generate Feedback Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe
  src="https://forms.gle/bvLLPX74LMGrFefo9"
  width="80%"
  height="1200px">
  Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />