# Intro to Reinforcement Learning



<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2022/blob/main/Indaba_2022_Prac_Template.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

© Deep Learning Indaba 2022. Apache License 2.0.

**Authors:**
Claude Formanek

**Introduction:** 

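This practical introduces the key concepts of reinforcement learning (agents, environments, observations, actions, policies, rewards and returns) using the CartPole environment from OpenAI Gym, and then uses them to implement a simple random policy search agent in Jax.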

**Topics:** 

Content: Reinforcement Learning

Level: Beginner

**Aims/Learning Objectives:**

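*   Understand the agent-environment loop and the roles of observations, actions, rewards and returns.
*   Distinguish deterministic from stochastic policies and transition functions.
*   Interact with a Gym environment using `env.reset()` and `env.step()`.
*   Write jit-compatible action selection and learn functions in Jax.
*   Implement random policy search to solve CartPole.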

**Prerequisites:**

Basic Python and NumPy. The Jax functions used in the exercises are linked where they are needed.

**Outline:** 

*   Section 1: Key Concepts in Reinforcement Learning
*   Section 2: Random Policy Search (RPS)



## Setup

For this practical, it will help to use a GPU runtime to speed up training. To do this, go to the `Runtime` menu in Colab, select `Change runtime type` and then in the popup menu, choose `GPU` in the `Hardware accelerator` box.

In [None]:
# @title Install required packages (Run Cell)
%%capture
!pip install jaxlib
!pip install jax
!pip install git+https://github.com/deepmind/dm-haiku
!pip install "gym<0.26" # the code below uses the pre-0.26 Gym API for reset() and step()
!pip install "gym[box2d]<0.26"
!pip install rlax
!pip install optax

In [None]:
# @title Import required packages (Run Cell)
%%capture
import random
import collections # useful data structures
import numpy as np
import gym # reinforcement learning environments
import jax
import jax.numpy as jnp # jax numpy
import haiku as hk # jax neural network library
import optax # jax optimizer library
import rlax # jax reinforcement learning library
import matplotlib.pyplot as plt



## Section 1: Key Concepts in Reinforcement Learning

Reinforcement Learning (RL) is a subfield of Machine Learning (ML). RL algorithms try to learn the optimal actions to take in an environment in order to maximise a reward signal. More precisely, in RL we have an **agent** which perceives an observation $o_t$ of the current state $s_t$ of the **environment** and must choose an action $a_t$ to take. The environment then transitions to a new state $s_{t+1}$ in response to the agent's action and gives the agent a scalar reward $r_t$ indicating how good or bad the chosen action was, given the environment's state. The goal in RL is for the agent to maximise the amount of reward it receives from the environment over time. The subscript $t$ indicates the timestep, i.e., $s_0$ is the state of the environment at the initial timestep and $a_{99}$ is the agent's action at timestep $t=99$.

### OpenAI Gym
OpenAI has provided a Python package called Gym that includes implementations of popular environments and a simple interface for an RL agent to interact with. To create a gym environment, all you need to do is pass the name of the environment to the function `gym.make(<environment_name>)`. In this tutorial we will be using a simple environment called CartPole. In **CartPole** the task is for the agent to learn to balance the pole for as long as possible by moving the cart *left* or *right*.

<img src="https://miro.medium.com/max/600/1*v8KcdjfVGf39yvTpXDTCGQ.gif" width="30%" />

In [None]:
# Create the environment
env = gym.make("CartPole-v1")

### States and observations
In RL the agent perceives an observation of the environment's state. In some settings the observation includes all the information underlying the environment's state; such an environment is called **fully observed**. In other settings the agent only receives partial information about the environment's state in its observation; such an environment is called **partially observed**. For the rest of this tutorial we will assume the environment is fully observed, so we will use state $s_t$ and observation $o_t$ interchangeably. In Gym we get the initial observation from the environment by calling `env.reset()`:

In [None]:
# Reset the environment
s_0 = env.reset()
print("Initial state:", s_0)

In CartPole the state of the environment is represented by four numbers, in this order: *the position of the cart, the velocity of the cart, the angle of the pole and the angular velocity of the pole*.
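The sketch below is an illustration, not part of Gym. It assumes Gym's CartPole layout `[cart position, cart velocity, pole angle, pole angular velocity]`, unpacks a made-up state into its named components, and builds a *partially observed* version of it by dropping both velocities. The helper `to_partial_observation` is hypothetical.

```python
import numpy as np

# Made-up state in Gym's CartPole layout:
# [cart position, cart velocity, pole angle (radians), pole angular velocity]
state = np.array([0.02, -0.15, 0.03, 0.25], dtype=np.float32)

# Unpack the four components by name.
cart_position, cart_velocity, pole_angle, pole_angular_velocity = state


def to_partial_observation(state):
    """Hypothetical helper: keep only the positions, dropping both velocities.

    An agent that only receives this output acts in a partially observed
    environment, because the velocities are part of the state but are
    hidden from the observation.
    """
    return state[[0, 2]]


partial_obs = to_partial_observation(state)
print("Pole angle:", pole_angle)
print("Partial observation:", partial_obs)
```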

### Actions
In RL, actions are usually either **discrete** or **continuous**. Continuous actions are given by a vector of real numbers. Discrete actions are given by an integer value. When the environment has a finite set of actions that we can enumerate, we usually use discrete actions. In CartPole there are only two actions, *left* and *right*, so they can be represented by the integers $0$ and $1$ (0 pushes the cart to the left and 1 pushes it to the right). In Gym we can easily inspect the action space and the number of possible actions:

In [None]:
# Get action space
print(f"Environment action space: {env.action_space}")

# Get num actions
num_actions = env.action_space.n
print(f"Number of actions: {num_actions}")
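To make the difference concrete, here is a small NumPy sketch, not tied to any environment, of what a discrete and a continuous action might look like. In Gym these two kinds of action space are typically represented by `gym.spaces.Discrete` and `gym.spaces.Box`; the 2-dimensional continuous action and its bounds below are made-up assumptions.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Discrete action space with 2 actions (like CartPole): an action is an integer.
num_discrete_actions = 2
discrete_action = int(rng.integers(num_discrete_actions))  # either 0 or 1

# Hypothetical continuous action space: a 2-D vector of real numbers,
# sampled uniformly from [low, high).
low, high = -1.0, 1.0
continuous_action = rng.uniform(low, high, size=2)

print("Discrete action:", discrete_action)
print("Continuous action:", continuous_action)
```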

### The agent's policy
In RL the agent chooses actions based on the observation it receives. We can think of the agent's action selection process as a function that takes an observation as input and returns an action as output. In RL we usually call this function the agent's **policy** and denote it $\pi(s_t)=a_t$.

In some cases we may want our policy to be stochastic, rather than deterministic. In such cases, actions are randomly sampled from a probability distribution that is conditioned on the observation. We denote stochastic policies $a_t\sim\pi(\cdot\ |\ s_t)$, where the symbol $\cdot$ is simply shorthand for *all actions* and $\sim$ means "*sampled from*".

**Exercise 1:** Bob has a deterministic policy $\pi$. The first time Bob uses his policy on some observation $o_t$ he chooses action $0$. What will Bob's chosen action be if he uses his policy a second time on the exact same observation $o_t$? Choose from the options below and assume there are only two possible actions:
1.   Bob will choose action 0 again.
2.   Bob will choose action 1.
3.   You can't say.

**Exercise 2:** Alice has a stochastic policy $\pi$. The first time Alice uses her policy on some observation $o_t$ she chooses action $0$. What will Alice's chosen action be if she uses her policy a second time on the exact same observation $o_t$? Choose from the options below and assume there are only two possible actions:
1.   Alice will choose action 0 again.
2.   Alice will choose action 1.
3.   You can't say.

In [None]:
#@title Exercise 1 & 2 solution
%%capture

""" 
Exercise 1: Since the policy is deterministic, Bob will always choose the same 
action given the same observation.

Exercise 2: Since the policy is stochastic, Alice's action is sampled at random
each time she uses her policy. So you can't say.
"""

As an exercise, we will now implement a deterministic policy as well as a stochastic policy for CartPole.

**Exercise 3:** Complete the function below. It should take an observation `obs` as input, compute the dot product between `obs` and the parameter vector `params=[1,-2,2,-1]`, and then return `0` if the result is less than or equal to zero and `1` if the result is greater than zero. Is this a deterministic or stochastic policy?

**Useful methods:** 
*   [Numpy dot product](https://numpy.org/doc/stable/reference/generated/numpy.dot.html)


**Notes:**

*   We already imported `numpy` as `np`.
*   Assume `obs` is also a vector of four numbers like `params`.

In [None]:
def choose_action(obs):

  action = None # you will need to overwrite this

  # YOUR CODE


  # END YOUR CODE

  return action

In [None]:
#@title Exercise 3 solution

def choose_action(obs):

  weights = np.array([1,-2,2,-1])
  dot_product = np.dot(obs, weights)

  # Action 1 if the dot product is positive, otherwise action 0.
  if dot_product > 0:
    action = 1
  else:
    action = 0

  return action

# TESTS
fixed_obs = [1,1,2,4]
print(f"Fixed observation: {fixed_obs}")
for i in range(10):
  print(f"Result of policy call number {i}: {choose_action(fixed_obs)}")



**Exercise 4:** Complete the function below. It should take an observation `obs` as input and compute the dot product between `obs` and the parameter vector `params=[1,-2,2,-1]`. If the result is less than or equal to zero, the function should return `0` 20% of the time and `1` 80% of the time. If the result is greater than zero, it should always return `0`. Is this a deterministic or stochastic policy?

**Useful methods:** 
*   [Numpy random choice](https://numpy.org/doc/stable/reference/random/generated/numpy.random.choice.html)

In [None]:
def choose_action(obs):

  weights = np.array([1,-2,2,-1])
  action = None # you will need to overwrite this

  # YOUR CODE


  # END YOUR CODE

  return action

# TESTS
fixed_obs = [1,1,2,4]
print(f"Fixed observation: {fixed_obs}")
for i in range(10):
  print(f"Result of policy call number {i}: {choose_action(fixed_obs)}")

In [None]:
#@title Exercise 4 solution

def choose_action(obs):

  weights = np.array([1,-2,2,-1])
  dot_product = np.dot(obs, weights)

  if dot_product <= 0:
    action = np.random.choice([0,1], p=[0.2, 0.8])
  else:
    action = 0

  return action

# TESTS
fixed_obs = [1,1,2,4]
print(f"Fixed observation: {fixed_obs}")
for i in range(10):
  print(f"Result of policy call number {i}: {choose_action(fixed_obs)}")



### The environment transition function
Now that we have a policy, we can pass actions from the agent to the environment. The environment will then transition into a new state in response to the agent's action. In RL we model this process with a **state transition function** $P$, which takes the current state $s_t$ and an action $a_t$ as input and returns the next state $s_{t+1}$ as output. Like policies, the state transition function can be either deterministic, $s_{t+1}=P(s_t, a_t)$, or stochastic, $s_{t+1}\sim P(\cdot\ |\ s_t, a_t)$. In Gym we pass actions to the environment by calling `env.step(<action>)`. The function returns four values: the next observation, the reward for the action taken, a boolean flag indicating whether the episode is done, and a dictionary with extra information.


In [None]:
# Get the initial obs by resetting the env
initial_obs = env.reset()

# Use the policy to choose an action
action = choose_action(initial_obs)

# Step the environment
next_obs, reward, done, info = env.step(action)

print("Observation:", initial_obs)
print("Action:", action)
print("Next observation:", next_obs)
print("Reward:", reward)
print("Game is done:", done)
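Gym hides CartPole's transition function inside `env.step()`. To see the difference between a deterministic and a stochastic transition function explicitly, here is a sketch of a made-up toy environment: a corridor with positions 0 to 4, where action 0 moves left and action 1 moves right. In the stochastic version we assume, purely for illustration, that the move "slips" with probability 0.2 and the agent stays where it is.

```python
import random

NUM_POSITIONS = 5  # toy corridor with positions 0, 1, 2, 3, 4


def deterministic_transition(state, action):
    """s_{t+1} = P(s_t, a_t): action 0 moves left, action 1 moves right."""
    step = -1 if action == 0 else 1
    # Clip so the agent cannot walk off either end of the corridor.
    return min(max(state + step, 0), NUM_POSITIONS - 1)


def stochastic_transition(state, action, slip_prob=0.2, rng=random):
    """s_{t+1} ~ P(. | s_t, a_t): with probability slip_prob the agent stays put."""
    if rng.random() < slip_prob:
        return state
    return deterministic_transition(state, action)


print(deterministic_transition(2, 1))  # always 3
rng = random.Random(0)
print([stochastic_transition(2, 1, rng=rng) for _ in range(10)])  # each entry is 3 (moved) or 2 (slipped)
```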

### Episode return
In RL we usually break the agent's interactions with the environment up into **episodes**. The sum of all the rewards collected during an episode is called the episode's **return**. The goal in RL is for the agent to choose actions which maximise the expected future return. In CartPole the agent receives a reward of `1` for every timestep, including the final one, so the episode return equals the number of timesteps the pole was kept up. The episode ends when the pole tilts more than 12 degrees from vertical, when the cart moves too far from the centre, or, in `CartPole-v1`, after 500 timesteps.
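The return is just a sum of rewards, so it is easy to compute. The sketch below uses a made-up reward sequence and also computes the return collected *from* each timestep until the end of the episode, which is the "future return" the agent tries to maximise from that point on. The helper names are hypothetical, and no discounting is applied, matching the definition above.

```python
def episode_return(rewards):
    """The (undiscounted) return of an episode: the sum of all its rewards."""
    return sum(rewards)


def future_returns(rewards):
    """For each timestep t, the sum of rewards from t until the end of the episode."""
    returns = []
    running_total = 0.0
    # Walk backwards so each timestep can reuse the total of the timesteps after it.
    for reward in reversed(rewards):
        running_total += reward
        returns.append(running_total)
    return returns[::-1]


rewards = [1.0, 1.0, 1.0, 1.0, 1.0]  # made-up rewards from a 5-timestep episode
print(episode_return(rewards))  # 5.0
print(future_returns(rewards))  # [5.0, 4.0, 3.0, 2.0, 1.0]
```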

### Agent-environment Loop
Now that we know what a policy is and how to step the environment, let's close the agent-environment loop.

**Exercise 5:** Write a function which runs one episode of CartPole by sequentially choosing actions and stepping the environment. You should use the stochastic policy we defined earlier to choose actions. The function should keep track of the rewards received and return the episode return at the end of the episode.

In [None]:
def run_episode(env):
  episode_return = 0

  ## YOUR CODE

  # HINT: reset environment

  # HINT: while loop until episode is done

    # HINT: choose action

    # HINT: step environment
    
    # HINT: add reward to episode_return

  ## END CODE

  return episode_return

In [None]:
#@title Exercise 5 solution

def run_episode(env):
  episode_return = 0
  
  obs = env.reset()
  done = False

  while not done:
    action = choose_action(obs)

    next_obs, reward, done, info = env.step(action)

    episode_return += reward

    # Critical: the next observation becomes the current one for the next step
    obs = next_obs

  return episode_return

# TEST
print("Episode return:", run_episode(env))

In `CartPole-v1` episodes are capped at 500 timesteps, so the maximum possible episode return is 500; we will consider the environment solved when the agent can reliably achieve this return. As you can see, our current policy is nowhere near optimal yet. Let's learn a way to find an optimal policy.

One way to find an optimal policy is to search for it randomly. In a complex environment, finding an optimal policy by randomly trying different strategies could take forever, but CartPole is simple enough that it might just work.

Before we implement random policy search, let's take a look at the environment loop function that we implemented for you; it is the environment loop we will use for the rest of the notebook. It uses a [NamedTuple](https://www.geeksforgeeks.org/namedtuple-in-python/) to bundle `obs`, `action`, `reward`, `next_obs` and the `done` flag into a **transition** object:

In [None]:
# Named tuple to store transition
Transition = collections.namedtuple("Transition", ["obs", "action", "reward", "next_obs", "done"])

# TEST
transition = Transition(
    obs=[1,2,-1,2],
    action=0,
    reward=10,
    next_obs=[1,2,2,1],
    done=True
)

print("Transition obs:", transition.obs)
print("Transition action:", transition.action)
print("Transition reward:", transition.reward)
print("Transition next obs:", transition.next_obs)
print("Transition done:", transition.done)

Below is the environment loop function we have implemented for you. We recommend reading through it and trying to understand it, but if anything is unclear, don't worry. It should all make more sense as we work through this tutorial.

In [None]:
# Environment loop
def run_environment_loop(rng, env, agent_params, agent_select_action_func, 
    agent_actor_state=None, agent_learn_func=None, agent_learner_state=None, 
    agent_memory=None, num_episodes=1000, evaluator_period=100, 
    evaluation_episodes=32, learner_steps_per_episode=1):
    """
    This function runs several episodes in an environment and periodically does 
    some agent learning and evaluation.
    
    Args:
        rng: a random number generator. This is for jax.
        env: a gym environment.
        agent_params: an object to store parameters that the agent uses.
        agent_select_func: a function that does action selection for the agent.
        agent_actor_state (optional): an object that stores the internal state 
            of the agents action selection function.
        agent_learn_func (optional): a function that does some learning for the 
            agent by updating the agent parameters.
        agent_learn_state (optional): an object that stores the internal state 
            of the agent learn function.
        agent_memory (optional): an object for storing an retrieving historical 
            experience.
        num_episodes: how many episodes to run.
        evaluator_period: how often to run evaluation.
        evaluation_episodes: how many evaluation episodes to run.

    Returns:
        episode_returns: list of all the episode returns.
        evaluator_episode_returns: list of all the evaluator episode returns.
    """
    episode_returns = [] # List to store history of episode returns.
    evaluator_episode_returns = [] # List to store history of evaluator returns.
    for episode in range(num_episodes):

        # Reset environment.
        obs = env.reset()
        episode_return = 0
        done = False

        while not done:

            # Agent select action.
            action, agent_actor_state = agent_select_action_func(
                                            next(rng), 
                                            agent_params, 
                                            agent_actor_state, 
                                            np.array(obs)
                                        )

            # Step environment.
            next_obs, reward, done, _ = env.step(int(action))

            # Pack into transition.
            transition = Transition(obs, action, reward, next_obs, done)

            # Add transition to memory.
            if agent_memory: # check if agent has memory
              agent_memory.push(transition)

            # Add reward to episode return.
            episode_return += reward

            # Set obs to next obs before next environment step. CRITICAL!!!
            obs = next_obs

        episode_returns.append(episode_return)

        # At the end of every episode we do a learn step.
        if agent_memory and agent_memory.is_ready(): # Make sure memory is ready

            for _ in range(learner_steps_per_episode):
                # First sample memory and then pass the result to the learn function
                memory = agent_memory.sample()
                agent_params, agent_learner_state = agent_learn_func(
                                                        next(rng), 
                                                        agent_params, 
                                                        agent_learner_state, 
                                                        memory
                                                    )

        if (episode % evaluator_period) == 0: # Do evaluation

            evaluator_episode_return = 0
            for eval_episode in range(evaluation_episodes):
                obs = env.reset()
                done = False
                while not done:
                    action, _ = agent_select_action_func(
                                    next(rng), 
                                    agent_params, 
                                    agent_actor_state, 
                                    np.array(obs), 
                                    evaluation=True
                                )

                    obs, reward, done, _ = env.step(int(action))

                    evaluator_episode_return += reward

            evaluator_episode_return /= evaluation_episodes

            evaluator_episode_returns.append(evaluator_episode_return)

            logs = [
                    f"Episode: {episode}",
                    f"Episode Return: {episode_return}",
                    f"Average Episode Return: {np.mean(episode_returns[-20:])}",
                    f"Evaluator Episode Return: {evaluator_episode_return}"
            ]

            print(*logs, sep="\t") # Print the logs

    return episode_returns, evaluator_episode_returns

Before we can test this environment loop function with the policy function we implemented earlier, we need to modify our policy function so that its interface matches what the environment loop expects. The `select_action` function should take a random key as its first argument, then the parameters, then the actor's internal state (more on this later), then the observation and finally an `evaluation` boolean flag that indicates whether the function is being called during the evaluation loop (more on this later). The function should return the chosen action and the next state of the actor.

In [None]:
def select_action(key, params, actor_state, obs, evaluation=False):
    """
    This function assumes params is a vector of the same size as obs.
    It computes the dot product between params and obs. If the result is
    less than zero it returns action 0. If the result is greater than
    or equal to zero, it returns action 1.
    """
    dot_product = np.dot(params, obs)
    
    if dot_product >= 0:
      action = 1
    else:
      action = 0

    return action, actor_state

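**Exercise 6:** Can you convert this into a function that only uses Jax methods so that we can jit the function? Inside a jitted function, `params` and `obs` are abstract tracers rather than concrete arrays, so NumPy functions and Python `if` statements on their values will fail; use the Jax equivalents listed below instead.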


**Useful functions:** 
*   [Jax dot product](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html)
*   [Jax select](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html#jax.lax.select)

**Note:**

We already imported `jax.numpy` as `jnp`.



In [None]:
def select_action(key, params, actor_state, obs, evaluation=False):

  # YOUR CODE

  # END YOUR CODE

  return action, actor_state

# TEST: the result of your function should be action 1 in this test
key = None # not used
actor_state = None # not used
select_action_jit = jax.jit(select_action) # jit the function
params = np.array([1,1,-1,-1], "float32")
obs = np.array([1,1,1,1], "float32")

action, actor_state = select_action_jit(key, params, actor_state, obs)
print("Action:", action)

In [None]:
#@title Exercise 6 solution

def select_action(key, params, actor_state, obs, evaluation=False):
    dot_product = jnp.dot(params, obs)

    action = jax.lax.select(
        dot_product >= 0.0,
        1,
        0,
    )

    return action, actor_state

# TEST: the result of your function should be action 1 in this test
key = None # not used
actor_state = None # not used
select_action_jit = jax.jit(select_action) # jit the function

params = np.array([1,1,-1,-1], "float32")
obs = np.array([1,1,1,1], "float32")

action, actor_state = select_action_jit(key, params, actor_state, obs)
print("Action:", action)

We can now test our `select_action` function in the `run_environment_loop` function.

In [None]:
# Jax random number generator
rng = hk.PRNGSequence(jax.random.PRNGKey(0)) # don't worry about this for now

# Some arbitrary parameters
params = np.array([1,1,-1,-1], "float32")

episode_returns, evaluator_returns = run_environment_loop(
                                        rng, 
                                        env, 
                                        params,
                                        select_action_jit,
                                        num_episodes=1001,
                                      )

print("Average episode returns:", np.mean(episode_returns))

### Agent memory

In many RL algorithms the agent uses some kind of memory to store the experiences it had in the environment. The interface we will use for the agent's memory is very simple. It has a method `memory.push(<transition>)` that adds information about a transition to the memory, a method `memory.is_ready()` to check whether the memory is ready for learning, and finally a method `memory.sample()` that returns information the agent's learn function can use to do learning.
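As an illustration of this interface (not something this notebook uses), here is a sketch of a different, hypothetical memory, `TransitionBuffer`, that keeps the most recent transitions and samples a random batch of them. Memories of this kind are common in RL, but the class name, its parameters and its batch size are assumptions made for this example. It deliberately does not define `__len__`: `run_environment_loop` checks `if agent_memory:`, and a memory that reported a length of zero while empty would be treated as missing and never receive any transitions.

```python
import collections
import random

Transition = collections.namedtuple(
    "Transition", ["obs", "action", "reward", "next_obs", "done"]
)


class TransitionBuffer:
    """Hypothetical memory: stores recent transitions, samples random batches.

    It does not define __len__, so an empty buffer is still truthy.
    """

    def __init__(self, max_size=1000, batch_size=32, seed=0):
        # A deque with maxlen silently drops the oldest transition when full.
        self.buffer = collections.deque(maxlen=max_size)
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def push(self, transition):
        self.buffer.append(transition)

    def is_ready(self):
        # Only allow learning once there are enough transitions for a full batch.
        return len(self.buffer) >= self.batch_size

    def sample(self):
        # Sample a batch of distinct transitions uniformly at random.
        return self.rng.sample(list(self.buffer), self.batch_size)


example_memory = TransitionBuffer(max_size=5, batch_size=3)
for t in range(7):
    example_memory.push(
        Transition(obs=t, action=0, reward=float(t), next_obs=t + 1, done=False)
    )
print(example_memory.is_ready(), [tr.reward for tr in example_memory.sample()])
```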

#### Average Episode Return Memory
We have built a simple agent memory module for you below. It collects the episode returns of 20 episodes. Read through our implementation below and see if you can understand it. The `memory.sample()` method returns the average episode return over those 20 episodes and then clears the buffer, so the memory becomes ready again after another 20 episodes.

In [None]:
# A NamedTuple to store the average episode return of the last 20 runs
AverageEpisodeReturnMemory = collections.namedtuple(
                                "AverageEpisodeReturnMemory", 
                                ["average_episode_return"]
                            )

class AverageEpisodeReturnBuffer:

    def __init__(self, num_episodes_to_store=20):
        """
        This class implements an agent memory that stores the average episode 
        return over the last 20 episodes.
        """
        self.num_episodes_to_store = num_episodes_to_store
        self.episode_return_buffer = []
        self.current_episode_return = 0

    def push(self, transition):
        self.current_episode_return += transition.reward

        if transition.done: # If the episode is done
            # Add episode return to buffer
            self.episode_return_buffer.append(self.current_episode_return)

            # Reset episode return
            self.current_episode_return = 0


    def is_ready(self):
        return len(self.episode_return_buffer) == self.num_episodes_to_store

    def sample(self):
        average_episode_return = np.mean(self.episode_return_buffer)

        # Clear episode return buffer
        self.episode_return_buffer = []

        return AverageEpisodeReturnMemory(average_episode_return)

## Section 2: Random Policy Search (RPS)
Now we are ready to implement the random policy search algorithm, which randomly tries different policies and keeps track of the best policy it has found so far. We will say that policy $A$ is better than policy $B$ if the average episode return of policy $A$ over 20 episodes is greater than that of policy $B$. We will need to modify the way we store the agent's parameters so that we always have access to both the current parameters and the best parameters.
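Before building the Jax version piece by piece, here is the whole idea in a few lines of plain NumPy. This is a sketch under simplifying assumptions: the hypothetical `toy_evaluate` function stands in for "run the policy for 20 episodes and average the returns", and candidate weights are sampled uniformly from $[-2, 2)$. Unlike the Jax version, each candidate here is scored immediately rather than over the following 20 episodes.

```python
import numpy as np


def random_policy_search(evaluate, num_params=4, num_iterations=200, seed=0):
    """Randomly try weight vectors and keep the best one found so far."""
    rng = np.random.default_rng(seed)
    best_weights, best_score = None, -np.inf
    for _ in range(num_iterations):
        # Propose new current weights, independently of the best ones.
        current_weights = rng.uniform(-2.0, 2.0, size=num_params)
        score = evaluate(current_weights)
        # Keep the current weights only if they beat the best so far.
        if score > best_score:
            best_weights, best_score = current_weights, score
    return best_weights, best_score


# Hypothetical stand-in for the average episode return of a policy:
# the closer the weights are to `target`, the higher the score.
target = np.array([0.5, -1.0, 1.5, 0.0])


def toy_evaluate(weights):
    return -float(np.sum((weights - target) ** 2))


best_weights, best_score = random_policy_search(toy_evaluate)
print("Best weights:", best_weights)
print("Best score:", best_score)
```

The outcome depends entirely on which weights happen to be sampled, so a different `seed` can give a noticeably better or worse result.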

In [None]:
# Parameter container for random policy search
RandomPolicySearchParams = collections.namedtuple("RandomPolicySearchParams", ["current", "best"])

# TEST: store two different sets of parameters
current_params = np.array([1,1,-1,-1])
best_params = np.array([0,0,0,0])
rps_params = RandomPolicySearchParams(current_params, best_params)

print(f"Best params: {rps_params.best}")
print(f"Current params: {rps_params.current}")

### RPS select action function
Now let's once again modify our `select_action` function so that it uses the best parameters when `evaluation==True` and the current parameters when `evaluation==False`.

**Exercise 7:** Implement the `random_policy_search_select_action` function as described above. Make sure you use Jax so that we can jit the function.

In [None]:
def random_policy_search_select_action(
    key, 
    params, 
    actor_state, 
    obs, 
    evaluation=False
):

  # YOUR CODE

  # HINT: best_action = ... (two lines)

  # HINT: current_action = ... (two lines)

  # HINT: action = best_action if evaluation else current_action (one line)

  # END YOUR CODE

  return action, actor_state

# TEST:
key = None # not used
actor_state = None # not used
random_policy_search_select_action_jit = jax.jit(
                                            random_policy_search_select_action
                                            ) # jit the function
# Parameters
current_params = np.array([-1,-1,-1,-1])
best_params = np.array([0,0,0,0])
rps_params = RandomPolicySearchParams(current_params, best_params)

# Observation
obs = np.array([1,1,1,1], "float32")

current_action, actor_state = random_policy_search_select_action_jit(
    key, 
    rps_params, 
    actor_state, 
    obs, 
    evaluation=False
)

best_action, actor_state = random_policy_search_select_action_jit(
    key, 
    rps_params, 
    actor_state, 
    obs, 
    evaluation=True
)

print("Current action:", current_action)
print("Best action:", best_action)

In [None]:
#@title Exercise 7 solution

def random_policy_search_select_action(
    key, 
    params, 
    actor_state, 
    obs, 
    evaluation=False
):

  # YOUR CODE

  dot_product = jnp.dot(params.best, obs)
  best_action = jax.lax.select(
      dot_product >= 0,
      1,
      0
  )

  dot_product = jnp.dot(params.current, obs)
  current_action = jax.lax.select(
      dot_product >= 0,
      1,
      0
  )

  action = jax.lax.select(
      evaluation,
      best_action,
      current_action
  )

  # END YOUR CODE

  return action, actor_state

# TEST:
key = None # not used
actor_state = None # not used
random_policy_search_select_action_jit = jax.jit(
                                            random_policy_search_select_action
                                            ) # jit the function
# Parameters
current_params = np.array([-1,-1,-1,-1])
best_params = np.array([0,0,0,0])
rps_params = RandomPolicySearchParams(current_params, best_params)

# Observation
obs = np.array([1,1,1,1], "float32")

current_action, actor_state = random_policy_search_select_action_jit(
    key, 
    rps_params, 
    actor_state, 
    obs, 
    evaluation=False
)

best_action, actor_state = random_policy_search_select_action_jit(
    key, 
    rps_params, 
    actor_state, 
    obs, 
    evaluation=True
)

print("Current action:", current_action)
print("Best action:", best_action)


### RPS learn function
Now we need to implement a learn function for our random policy search agent. The learn function is quite simple. All we need to do is check whether the current weights are better than the best weights. If they are, the current weights become the new best weights. Either way, we then randomly sample a new set of current weights to try next. Let's assume our learn function receives a memory sample from the `AverageEpisodeReturnBuffer` we implemented earlier; we can use its average episode return to compare the current weights to the best weights. We also need to keep track of the best average episode return so far, for which we can use the `learner_state` argument.

In [None]:
# A NamedTuple to store the best average episode return so far
LearnerState = collections.namedtuple(
                                      "LearnerState", 
                                      ["count", "best_average_episode_return"]
                                      )

**Exercise 8:** Write a function to randomly sample new weights using jax. The weights should be sampled from the interval [-2,2].

**Useful functions:** 
*   [Jax random uniform sample](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.uniform.html#jax.random.uniform)

In [None]:
def get_new_random_weights(random_key, old_weights):
    new_weights_shape = old_weights.shape
    new_weights_dtype = old_weights.dtype
    new_weights = None # you should overwrite this

    # YOUR CODE
    
    # new_weights = ...

    # END YOUR CODE

    return new_weights

# TEST 
old_weights = np.array([1,1,1,1], "float32")
get_new_random_weights_jit = jax.jit(get_new_random_weights) # jit the function
random_key = next(rng) # get next random key from the random number generator
new_weights = get_new_random_weights_jit(random_key, old_weights)
print("New weights:", new_weights)


In [None]:
#@title Exercise 8 solution

def get_new_random_weights(random_key, old_weights):
    new_weights_shape = old_weights.shape
    new_weights_dtype = old_weights.dtype
    new_weights = None # you should overwrite this

    # YOUR CODE
    
    new_weights = jax.random.uniform(
                      random_key, 
                      new_weights_shape, 
                      new_weights_dtype,
                      minval=-2.0,
                      maxval=2.0
                  )

    # END YOUR CODE

    return new_weights

# TEST 
old_weights = np.array([1,1,1,1], "float32")
get_new_random_weights_jit = jax.jit(get_new_random_weights) # jit the function
random_key = next(rng) # get next random key from the random number generator
new_weights = get_new_random_weights_jit(random_key, old_weights)
print("New weights:", new_weights)



Now let's implement the random policy search learn function.

**Exercise 9:** Use the description of the random policy search learn function at the top of this section to complete the function below. Try to use jax (remember `jax.lax.select()`).

In [None]:
def random_policy_search_learn(key, params, learner_state, memory):
    best_weights = params.best 
    current_weights = params.current

    average_episode_return = memory.average_episode_return
    best_average_episode_return = learner_state.best_average_episode_return


    # YOUR CODE

    # HINT: if current better than best then ...

    # END YOUR CODE

    # Generate new random weights
    new_weights = get_new_random_weights_jit(key, best_weights)

    # Bundle weights in RPS Params NamedTuple
    params = RandomPolicySearchParams(current=new_weights, best=best_weights)

    # Increment the learn counter by one
    learn_count = learner_state.count + 1

    return params, LearnerState(learn_count, best_average_episode_return)

In [None]:
#@title Exercise 9 solution
def random_policy_search_learn(key, params, learner_state, memory):

    best_weights = params.best 
    current_weights = params.current

    current_episode_return = memory.average_episode_return
    best_average_episode_return = learner_state.best_average_episode_return

    best_weights = jax.lax.select(
        current_episode_return > best_average_episode_return,
        current_weights,
        best_weights
    )
        
    best_average_episode_return = jax.lax.select(
        current_episode_return > best_average_episode_return,
        current_episode_return,
        best_average_episode_return
    )

    # Generate new random weights
    new_weights = get_new_random_weights_jit(key, best_weights)

    # Bundle weights in RPS Params NamedTuple
    params = RandomPolicySearchParams(current=new_weights, best=best_weights)

    # Increment the learn counter by one
    learn_count = learner_state.count + 1

    return params, LearnerState(learn_count, best_average_episode_return)
    

Now we can put everything together using the environment loop.

In [None]:
initial_learner_state = LearnerState(0, -float("inf"))

# Jit the learn function for some extra speed
random_policy_search_learn_jit = jax.jit(random_policy_search_learn)

initial_weights = np.array([1,1,1,1], "float32")
initial_params = RandomPolicySearchParams(initial_weights, initial_weights)

memory = AverageEpisodeReturnBuffer(num_episodes_to_store=20)

episode_return, evaluator_episode_returns = run_environment_loop(
                                        rng, 
                                        env, 
                                        initial_params, 
                                        random_policy_search_select_action_jit, 
                                        None, # no actor state
                                        random_policy_search_learn_jit, 
                                        initial_learner_state, 
                                        memory, 
                                        num_episodes=2001
                                    )

# Plot graph of evaluator episode returns
plt.plot(evaluator_episode_returns)
plt.title("Random Search Evaluator Episode Return")
plt.show()



Hopefully you found a set of optimal parameters for CartPole. You will know you have if the evaluator episode return eventually reaches 500. If you didn't find optimal parameters, try running the environment loop again; you were probably just a little unlucky. That is, after all, the big limitation of random policy search: if you are unlucky, you might never (randomly) stumble on the optimal policy.

So, in random policy search there is very little (if any) real learning going on. Next, let's implement a simple RL algorithm that uses its experience to guide the learning process, rather than just randomly sampling new weights. Hopefully this will help us find an optimal policy more reliably.

## Section 3: Policy Gradients (PG)
As discussed, the goal in RL is to find a policy that maximises the expected cumulative reward (return) the agent receives from the environment. We can write the expected return of a policy as:

$J(\pi_\theta)=\mathrm{E}_{\tau\sim\pi_\theta}\ [R(\tau)]$,

where $\pi_\theta$ is a policy parametrised by $\theta$, $\mathrm{E}$ means *expectation*, $\tau$ is shorthand for "*episode*", $\tau\sim\pi_\theta$ is shorthand for "*episodes sampled using the policy* $\pi_\theta$", and $R(\tau)$ is the return of episode $\tau$.

Then, the goal in RL is to find the parameters $\theta$ that maximise the function $J(\pi_\theta)$. One way to find these parameters is to perform gradient ascent on $J(\pi_\theta)$ with respect to the parameters $\theta$: 

$\theta_{k+1}=\theta_k + \alpha \nabla_\theta J(\pi_\theta)|_{\theta_{k}}$,

where $\nabla_\theta J(\pi_\theta)|_{\theta_{k}}$ is the gradient of the expected return with respect to the policy parameters, evaluated at $\theta_k$, and $\alpha$ is the step size. This quantity, $\nabla_\theta J(\pi_\theta)$, is called the **policy gradient** and is very important in RL. If we can compute the policy gradient, then we have a way to directly optimise our policy.
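A minimal sketch of the gradient-ascent update, using `jax.grad` on a made-up one-dimensional objective. `toy_objective` is a hypothetical stand-in for $J(\pi_\theta)$ with its maximum at $\theta = 2$, and the step size `alpha = 0.1` is an arbitrary choice. In RL we cannot differentiate $J(\pi_\theta)$ directly like this, because it is an expectation over episodes produced by interacting with the environment; that is why we need the policy gradient result.

```python
import jax
import jax.numpy as jnp

def toy_objective(theta):
    # Hypothetical stand-in for J(pi_theta): a concave function maximised at theta = 2.
    return -jnp.square(theta - 2.0)

grad_J = jax.grad(toy_objective)  # a function that returns dJ/dtheta at a given theta

theta = jnp.array(0.0)  # initial parameters theta_0
alpha = 0.1             # step size

for k in range(100):
    # Gradient *ascent*: step in the direction of the gradient.
    theta = theta + alpha * grad_J(theta)

print("theta after 100 steps:", theta)  # close to 2.0
```

Each step moves $\theta$ a fixed fraction of the way towards 2, because the gradient $-2(\theta - 2)$ always points uphill.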

As it turns out, there is a way to compute the policy gradient; the mathematical derivation can be found [here](https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html). In this tutorial we will omit the derivation and just give you the result:


$\nabla_{\theta} J(\pi_{\theta})=\underset{\tau \sim \pi_{\theta}}{\mathrm{E}}[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid s_{t}) R(\tau)]$

Informally, the policy gradient is the expected value (over episodes sampled with $\pi_\theta$) of the sum, over all timesteps of the episode, of the gradient of the log-probability of the chosen action multiplied by the return of that episode.
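To see what the formula computes, here is a minimal sketch for a single sampled episode under two simplifying assumptions: the policy ignores the state (its logits are just the parameter vector `theta`), and there are 2 actions. `log_prob` and `pg_surrogate` are hypothetical helper names. Rather than differentiating the probabilities by hand, we build the sum $\sum_t \log \pi_\theta(a_t) R(\tau)$ and let `jax.grad` differentiate it; averaging this over many episodes would approximate the expectation.

```python
import jax
import jax.numpy as jnp

def log_prob(theta, action):
    # Hypothetical state-independent policy: the logits are the parameters themselves.
    return jax.nn.log_softmax(theta)[action]

def pg_surrogate(theta, actions, episode_return):
    # sum_t log pi_theta(a_t) * R(tau); its gradient is a single-episode
    # estimate of the policy gradient.
    log_probs = jax.vmap(log_prob, in_axes=(None, 0))(theta, actions)
    return jnp.sum(log_probs) * episode_return

theta = jnp.zeros(2)            # uniform policy over 2 actions
actions = jnp.array([0, 0, 1])  # actions taken in one sampled episode
episode_return = 2.0            # R(tau) for that episode

policy_grad = jax.grad(pg_surrogate)(theta, actions, episode_return)
print(policy_grad)  # [ 1. -1.]
```

With a uniform policy, the gradient of $\log \pi_\theta(a)$ is the one-hot vector of $a$ minus $[0.5, 0.5]$, so action 0 (taken twice in this rewarding episode) receives a positive gradient and action 1 a negative one.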


### REINFORCE
REINFORCE is a simple RL algorithm that uses the policy gradient to find the optimal policy by increasing the probability of choosing actions that tend to lead to high return episodes.

**Exercise 10:** Implement a function that takes the probability of an action and the return of the episode the action was taken in and returns the log of the probability multiplied by the return. Make sure you use jax so that we can jit the function.

**Useful functions:**
*   [Jax numpy log](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log.html)

In [None]:
def compute_weighted_log_prob(action_prob, episode_return):
    weighted_log_prob = None # you will need to overwrite this

    # YOUR CODE


    # END YOUR CODE

    return weighted_log_prob

# TEST: the result should be -22.314354
action_prob = 0.8
episode_return = 100
print("Weighted log prob:", compute_weighted_log_prob(action_prob, episode_return))

In [None]:
#@title Exercise 10 solution

def compute_weighted_log_prob(action_prob, episode_return):
    weighted_log_prob = None # you will need to overwrite this

    # YOUR CODE

    weighted_log_prob = jnp.log(action_prob) * episode_return

    # END YOUR CODE

    return weighted_log_prob

# TEST
action_prob = 0.8
episode_return = 100
print("Weighted log prob:", compute_weighted_log_prob(action_prob, episode_return))


### Rewards-to-go
Weighting the gradient of each action's log-probability by the return of the whole episode tends to push up the probability of actions in episodes with high return, regardless of where in the episode the action was taken. This does not really make sense: an action near the end of an episode may be reinforced simply because lots of reward was collected earlier in the episode, before the action was taken. RL agents should only reinforce actions on the basis of their *consequences*. Rewards obtained before taking an action have no bearing on how good that action was; only rewards that come after it matter. The cumulative reward received from timestep $t$ onwards is called the **rewards-to-go** and can be computed as:

$\hat{R}_t=\sum_{t'=t}^{T} r_{t'}$

Compare the rewards-to-go with the episode return:

$R(\tau)=\sum_{t=0}^Tr_t$

Thus, the policy gradient with rewards-to-go is given by:

$\nabla_{\theta} J(\pi_{\theta})=\underset{\tau \sim \pi_{\theta}}{\mathrm{E}}[\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid s_{t}) \hat{R}_t]$

**Exercise 11:** Implement a function that takes a list of all the rewards obtained in an episode and computes the rewards-to-go. Don't worry about using jax in this function. You can use regular Python operations like `for-loops`.

In [None]:
# Implement reward to go
def rewards_to_go(rewards):
    """
    This function should take a list of rewards as input and 
    compute the rewards-to-go for each timestep.
    
    Arguments:
        rewards[t] is the reward at time step t.

    Returns:
        rewards_to_go[t] should be the reward-to-go at timestep t.
    """

    rewards_to_go = []

    # YOUR CODE


    # END YOUR CODE

    return rewards_to_go

# TEST: The result should be [10, 9, 7, 4]
rewards = np.array([1,2,3,4])
print("Rewards-to-go:", rewards_to_go(rewards))

In [None]:
#@title Exercise 11 solution

def rewards_to_go(rewards):
    rewards_to_go = []
    for i in range(len(rewards)):
        r2g = 0
        for j in range(i, len(rewards)):
            r2g += rewards[j]
        rewards_to_go.append(r2g)
    return rewards_to_go

# TEST: The result should be [10, 9, 7, 4]
rewards = np.array([1,2,3,4])
print("Rewards-to-go:", rewards_to_go(rewards))

In [None]:
# Faster rewards to go calculation using numpy
def rewards_to_go(rewards):
    return np.flip(np.cumsum(np.flip(rewards)))

# TEST: The result should be [10, 9, 7, 4]
rewards = np.array([1,2,3,4])
print("Rewards-to-go:", rewards_to_go(rewards))
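If you ever need rewards-to-go inside jitted code, the same reverse-cumulative-sum trick works with `jax.numpy`. This is a sketch with a hypothetical name, `rewards_to_go_jax`; the notebook's memory buffer calls the NumPy version on plain Python lists, so you do not need it here. Note that `jax.jit` recompiles for every new input length, so variable-length episodes would trigger repeated compilation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def rewards_to_go_jax(rewards):
    # Reverse, take the running sum, then reverse back:
    # entry t becomes rewards[t] + rewards[t+1] + ... + rewards[T].
    return jnp.cumsum(rewards[::-1])[::-1]

print(rewards_to_go_jax(jnp.array([1, 2, 3, 4])))  # [10  9  7  4]
```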

Next, we need a new agent memory that stores the rewards-to-go $\hat{R}_t$ along with the observation $o_t$ and action $a_t$ at every timestep.

In [None]:
# Now we need a new episode memory buffer

EpisodeRewardsToGoMemory = collections.namedtuple("EpisodeRewardsToGoMemory", ["obs", "action", "reward_to_go"])

class EpisodeRewardsToGoBuffer:

    def __init__(self, num_transitions_to_store=500, batch_size=500):
        self.batch_size = batch_size
        self.memory_buffer = collections.deque(maxlen=num_transitions_to_store)
        self.current_episode_transition_buffer = []

    def push(self, transition):
        self.current_episode_transition_buffer.append(transition)

        if transition.done:

            episode_rewards = []
            for t in self.current_episode_transition_buffer:
                episode_rewards.append(t.reward)

            r2g = rewards_to_go(episode_rewards)

            for i, t in enumerate(self.current_episode_transition_buffer):
                memory = EpisodeRewardsToGoMemory(t.obs, t.action, r2g[i])
                self.memory_buffer.append(memory)

            # Reset episode buffer
            self.current_episode_transition_buffer = []


    def is_ready(self):
        return len(self.memory_buffer) >= self.batch_size

    def sample(self):
        random_memory_sample = random.sample(self.memory_buffer, self.batch_size)

        obs_batch, action_batch, reward_to_go_batch = zip(*random_memory_sample)

        return EpisodeRewardsToGoMemory(
            np.stack(obs_batch).astype("float32"), 
            np.asarray(action_batch).astype("int32"), 
            np.asarray(reward_to_go_batch).astype("float32") # rewards-to-go are not always integers
        )


# Instantiate Memory
REINFORCE_memory = EpisodeRewardsToGoBuffer(num_transitions_to_store=512, batch_size=256)

### Policy neural network
Next, we need to approximate the policy using a simple neural network. Our policy network takes the observation as input, passes it through two hidden layers, and outputs one scalar value (a logit) for each possible action. So, in CartPole the output layer will have size `2`.

[Haiku](https://github.com/deepmind/dm-haiku) is a library for implementing neural networks in JAX. Below we have implemented a simple function that makes the policy network for you. 


In [None]:
def make_policy_network(num_actions: int, layers=[10, 10]) -> hk.Transformed:
  """Factory for a simple MLP network for the policy."""

  def policy_network(obs):
    network = hk.Sequential(
        [
            hk.Flatten(),
            hk.nets.MLP(layers + [num_actions])
        ]
    )
    return network(obs)

  return hk.without_apply_rng(hk.transform(policy_network))

Haiku networks have two important functions you need to know about. The first is `<network>.init(<rng>, <input>)`, which returns a set of randomly initialised parameters. The second is `<network>.apply(<params>, <input>)`, which passes an input through the network using the parameters provided.

In [None]:
# Example
POLICY_NETWORK = make_policy_network(num_actions=2, layers=[20,20])
random_key = next(rng) # get next random key
dummy_obs = np.array([1,1,1,1], "float32")

REINFORCE_params = POLICY_NETWORK.init(random_key, dummy_obs)
print("Initial params:", REINFORCE_params.keys())

output = POLICY_NETWORK.apply(REINFORCE_params, dummy_obs)
print("Policy network output:", output)


The outputs of our policy network are [logits](https://qr.ae/pv4YTe). To convert these into a probability distribution over actions, we pass the logits through the [softmax](https://en.wikipedia.org/wiki/Softmax_function) function (more on this later).
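Softmax exponentiates each logit and normalises, $\pi(a \mid s) = e^{z_a} / \sum_{b} e^{z_b}$, where $z$ is the vector of logits. A quick check with `jax.nn.softmax` on the logits `[0, 1]` that the tests in this section use:

```python
import jax
import jax.numpy as jnp

logits = jnp.array([0.0, 1.0])
probs = jax.nn.softmax(logits)  # exp(logits) / sum(exp(logits))

print(probs)        # approximately [0.2689 0.7311]
print(probs.sum())  # 1.0, up to floating-point rounding

# Shifting all logits by the same constant gives the same probabilities.
shifted_probs = jax.nn.softmax(logits + 5.0)
print(jnp.allclose(probs, shifted_probs))  # True
```

Because adding a constant to every logit leaves the probabilities unchanged, only the differences between logits matter.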

### Action selection

**Exercise 12:** Complete the function below which takes a vector of logits and randomly samples an action. 

**Useful functions:**
*   [Jax random categorical](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.categorical.html)

In [None]:
def sample_action(random_key, logits):
    action = None # you will need to overwrite this
    
    # YOUR CODE HERE


    # END YOUR CODE

    return action

# TEST
for i in range(10):
  print("Action:", sample_action(next(rng), np.array([0, 1], "float32")))

In [None]:
#@title Exercise 12 solution

def sample_action(random_key, logits):
    action = None # you will need to overwrite this
    
    # YOUR CODE HERE
    action = jax.random.categorical(random_key, logits)
    # END YOUR CODE

    return action

# TEST: 
for i in range(10):
  print("Action:", sample_action(next(rng), np.array([0, 1], "float32")))


Notice in the tests that the actions are randomly sampled. Of course, the action with the higher probability will be chosen more often, but there is always a chance that the action with the lower probability will be chosen. This is actually desirable in RL because it means the agent keeps trying new things in the environment. We call this **exploring**. Exploring is important because it helps the agent discover new, possibly better, strategies. When an agent chooses the best possible action (given its current knowledge), we say the agent is being **greedy**.
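To make the exploration argument concrete, the sketch below draws 10,000 actions from the logits `[0, 1]` in a single call, using the `shape` argument of `jax.random.categorical`, and compares how often action 1 is sampled with its softmax probability. A greedy selector would pick action 1 every time; sampling picks it only about 73% of the time. The seed is arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
logits = jnp.array([0.0, 1.0])

# Draw 10,000 actions at once using the `shape` argument.
actions = jax.random.categorical(key, logits, shape=(10_000,))

empirical_freq = jnp.mean(actions == 1)   # fraction of samples that chose action 1
softmax_prob = jax.nn.softmax(logits)[1]  # probability the policy assigns to action 1

print("Empirical frequency of action 1:", empirical_freq)
print("Softmax probability of action 1:", softmax_prob)
```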

**Exercise 13:** Complete the function below which takes a vector of logits and returns the greedy action. 

**Useful functions:**
*   [Jax numpy argmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmax.html)


In [None]:
def greedy_action(logits):
    action = None # you will need to overwrite this

    # YOUR CODE

    # END YOUR CODE

    return action

# TEST
for i in range(10):
    print("Action:", greedy_action(np.array([0, 1], "float32")))

In [None]:
#@title Exercise 13 solution

def greedy_action(logits):
    action = None # you will need to overwrite this

    # YOUR CODE
    action = jnp.argmax(logits)
    # END YOUR CODE

    return action

# TEST
for i in range(10):
    print("Action:", greedy_action(np.array([0, 1], "float32")))


Notice that the greedy action selector always chooses the same action, namely the one with the highest probability (or equivalently the largest logit). Next, we will implement the REINFORCE `select_action` function. The function passes the observation through the policy neural network to get logits and then uses them to choose an action. If `evaluation` is `True`, the greedy action is chosen. Otherwise, the action is sampled from the action probability distribution given by the logits.


In [None]:
def REINFORCE_select_action(key, params, actor_state, obs, evaluation=False):
    obs = jnp.expand_dims(obs, axis=0) # add dummy batch dim
    logits = POLICY_NETWORK.apply(params, obs)[0] # remove batch dim

    sampled_action = sample_action(key, logits)

    best_action = greedy_action(logits)

    action = jax.lax.select(
        evaluation,
        best_action,
        sampled_action
    )
    
    return action, actor_state

# TEST
action, actor_state = REINFORCE_select_action(
    key=next(rng),
    params=REINFORCE_params, # we instantiated this earlier
    actor_state=None, # not used
    obs=np.array([1,1,1,1], "float32") # dummy obs
)

print("Action:", action)

Now that we have finished the REINFORCE action selection function, all we have left to do is make a REINFORCE learn function. The learn function builds on the `compute_weighted_log_prob` function we made earlier to compute a loss whose gradient is the policy gradient, and then applies the resulting updates to our neural network.

### Network Optimiser

To apply updates to our neural network we will use a Jax library called [Optax](https://github.com/deepmind/optax). Optax has an implementation of the [Adam optimizer](https://www.geeksforgeeks.org/intuition-of-adam-optimizer/) which we can use.

In [None]:
REINFORCE_OPTIMIZER = optax.adam(1e-3)
REINFORCE_optim_state = REINFORCE_OPTIMIZER.init(REINFORCE_params)

### Policy gradient loss

**Exercise 14:** Complete the `pg_loss` function below.

**Useful methods:**
*   [Jax softmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softmax.html)
*   [Jax one-hot vector](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html)
*   [Jax dot product](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.dot.html)

In [None]:
def pg_loss(action, logits, reward_to_go):
    chosen_action_prob = None # you will need to overwrite this

    # YOUR CODE

    # HINT all_action_probs = ... (convert logits into probs)

    # HINT extract the prob of the desired action.
    # One way to achieve this is to use a one-hot vector and a dot product...?

    # END YOUR CODE
    weighted_log_prob = compute_weighted_log_prob(
                            chosen_action_prob, 
                            reward_to_go
                        )
    
    loss = - weighted_log_prob # negative because we want gradient `ascent`
    
    return loss

# TEST 
print("Policy gradient loss:", pg_loss(0, np.array([0,1], "float32"), 10))

In [None]:
#@title Exercise 14 solution

def pg_loss(action, logits, reward_to_go):

    # YOUR CODE
    
    all_action_probs = jax.nn.softmax(logits)
    action_mask = jax.nn.one_hot(action, logits.shape[0])
    chosen_action_prob = jnp.dot(all_action_probs, action_mask)

    # END YOUR CODE
    
    weighted_log_prob = compute_weighted_log_prob(
                            chosen_action_prob, 
                            reward_to_go
                        )
    
    loss = - weighted_log_prob 
    
    return loss

# TEST 
print("Policy gradient loss:", pg_loss(0, np.array([0,1], "float32"), 10))
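The solution above computes `jnp.log(jax.nn.softmax(logits))`. That is fine for CartPole-sized logits, but as a possible alternative (not required for the exercise), `jax.nn.log_softmax` computes log-probabilities directly and stays finite even when a probability underflows to 0 in `float32`. A sketch with a hypothetical name, `pg_loss_log_softmax`, that should agree with `pg_loss` on the test input above:

```python
import jax
import jax.numpy as jnp

def pg_loss_log_softmax(action, logits, reward_to_go):
    # log pi(a|s) computed directly from the logits, without taking
    # jnp.log of a probability that may have underflowed to 0.
    log_probs = jax.nn.log_softmax(logits)
    chosen_log_prob = log_probs[action]
    return -chosen_log_prob * reward_to_go

# Same inputs as the test above: -log(softmax([0, 1])[0]) * 10.
print(pg_loss_log_softmax(0, jnp.array([0.0, 1.0]), 10.0))  # approximately 13.1326

# With very spread-out logits, softmax rounds the small probability to 0 in float32,
# so jnp.log(softmax(...)) gives -inf, while log_softmax stays finite.
extreme_logits = jnp.array([0.0, 200.0])
print(jnp.log(jax.nn.softmax(extreme_logits))[0])  # -inf
print(jax.nn.log_softmax(extreme_logits)[0])       # -200.0
```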

Now, when we do a policy gradient update step we are going to want to do it using a batch of experience, rather than just a single experience like above. We can use Jax's [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) function to easily make our `pg_loss` function work on a batch of experience.

In [None]:
def batched_pg_loss(params, obs_batch, action_batch, reward_to_go_batch):
    logits_batch = POLICY_NETWORK.apply(params, obs_batch) # network we made earlier
    pg_loss_batch = jax.vmap(pg_loss)(action_batch, logits_batch, reward_to_go_batch) # add batch
    mean_pg_loss = jnp.mean(pg_loss_batch)
    return mean_pg_loss

# TEST
obs_batch = np.array([[1,0,0,1],[1,0,0,1],[1,0,0,1]], "float32")
actions_batch = np.array([1,0,0])
rew2go_batch = np.array([2.3, 4.3, 2.1])

loss = batched_pg_loss(REINFORCE_params, obs_batch, actions_batch, rew2go_batch)

print("PG loss on batch:", loss)

Now we can make the REINFORCE learn function.

In [None]:
REINFORCELearnerState = collections.namedtuple("REINFORCELearnerState", ["optim_state"])

def REINFORCE_learn(key, params, learner_state, memory):
    
    #Get Policy gradient by using `jax.grad()` on `batched_pg_loss`
    grad_loss = jax.grad(batched_pg_loss)(params, memory.obs, memory.action, memory.reward_to_go)

    # Get param updates using gradient and optimizer
    updates, new_optim_state = REINFORCE_OPTIMIZER.update(grad_loss, learner_state.optim_state)

    # Apply updates to params
    params = optax.apply_updates(params, updates)

    return params, REINFORCELearnerState(new_optim_state) # update learner state

# Let's jit the learn function for some extra speed
REINFORCE_learn_jit = jax.jit(REINFORCE_learn)

Now we can train our REINFORCE agent by putting everything together using the environment loop. 

In [None]:
learner_state = REINFORCELearnerState(REINFORCE_optim_state)
actor_state = None # not used

episode_returns, evaluator_returns = run_environment_loop(rng, env, REINFORCE_params, jax.jit(REINFORCE_select_action), actor_state, 
    REINFORCE_learn_jit, learner_state, REINFORCE_memory, num_episodes=1001)

# Plot the episode returns over time
plt.plot(episode_returns)
plt.show()


## Section 4: Q-Learning
Another common approach to finding an optimal policy in RL is Q-learning.

### Value functions
In Q-learning the agent learns a function that approximates the **value** of states and actions. By *value* we mean the expected return if you start in a particular state (and possibly take a particular first action) and then act according to a particular policy $\pi$ forever after. There are two types of **value functions**:
1. The **state value function**, which returns the expected return when starting from a particular state.

  $V_\pi(s)=\mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s\right]$

2. The **state-action value function**, which returns the expected return when starting from a particular state and choosing a particular action.

  $Q_\pi(s,a)=\mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s,\ a_0=a\right]$

We say that the value function $V_\pi(s)$ or $Q_\pi(s,a)$ is the **optimal** value function if the policy $\pi$ is an optimal policy. We denote the optimal value functions as follows:

1.   $V_\ast(s)=\max_\pi \ \mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s\right]$
2.   $Q_\ast(s,a)=\max_\pi \ \mathrm{E}_{\tau\sim\pi}\left[R(\tau) \mid s_0=s,\ a_0=a\right]$

There is an important relationship between the optimal action $a_\ast$ in a state $s$ and the optimal state-action value function $Q_\ast$: the optimal action in state $s$ is the action that maximises $Q_\ast(s, a)$. This relationship naturally induces an optimal policy:

$\pi_\ast(s)=\arg\max_a\ Q_\ast(s, a)$



### Greedy action selection

**Exercise 15:** Let's implement a function that, given a vector of Q-values, returns the action with the largest Q-value (i.e. the greedy action).

**Useful methods:**
*   [Jax argmax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmax.html)

In [None]:
# Implement a function that takes Q-values as input and returns the greedy action
def select_greedy_action(q_values):
    action = None # you will need to overwrite this

    # YOUR CODE


    # END YOUR CODE

    return action

# TEST
q_values = np.array([1, 3])
print("Greedy action:", select_greedy_action(q_values))

In [None]:
#@title Exercise 15 solution


def select_greedy_action(q_values):
    action = None # you will need to overwrite this
    
    # YOUR CODE

    action = jnp.argmax(q_values)

    # END YOUR CODE

    return action

# TEST
q_values = np.array([1, 3])
print("Greedy action:", select_greedy_action(q_values))

### Random exploration

**Exercise 16:** Now implement a function that, given the number of possible actions, returns a random action.

**Useful methods:**

*   [Jax random int](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.randint.html)

In [None]:
def select_random_action(key, num_actions):
    action = None

    # YOUR CODE


    # END YOUR CODE

    return action

# TEST
for i in range(10):
  print(f"Random action number {i}: {select_random_action(next(rng), 2)}")

In [None]:
#@title Exercise 16 solution

def select_random_action(key, num_actions):
    action = None

    # YOUR CODE

    action = jax.random.randint(
        key, 
        shape=(1,), 
        minval=0, 
        maxval=num_actions
    )[0] # important to take zeroth element

    # END YOUR CODE

    return action

# TEST
for i in range(10):
  print(f"Random action number {i}: {select_random_action(next(rng), 2)}")

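**Exercise 17:** Implement a function that takes the number of timesteps as input and returns the current epsilon value. Epsilon, the probability of taking a random action, should start at `1.0`, decrease linearly at a rate that would reach `0` after `EPSILON_DECAY_TIMESTEPS` timesteps, and never drop below `EPSILON_MIN`.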

In [None]:
#@title Exercise 17 solution

EPSILON_DECAY_TIMESTEPS = 3000
EPSILON_MIN = 0.05 # 5%

def get_epsilon(num_timesteps):
    epsilon = 1.0 - num_timesteps / EPSILON_DECAY_TIMESTEPS

    # Clip epsilon so that it never drops below EPSILON_MIN
    epsilon = jnp.maximum(epsilon, EPSILON_MIN)

    return epsilon

# TEST

print("Epsilon after 10 timesteps:", get_epsilon(10))
print("Epsilon after 10 010 timesteps:", get_epsilon(10_010))


### Epsilon-greedy action selection

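**Exercise 18:** Now let's put these functions together to do epsilon-greedy action selection: with probability epsilon the agent explores by taking a random action, and otherwise it takes the greedy action.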

In [None]:
def select_epsilon_greedy_action(key, q_values, num_timesteps):
    action = None # you will need to overwrite this

    # YOUR CODE HERE


    # END YOUR CODE

    return action

# TEST
dummy_q_values = jnp.array([0,1], jnp.float32)
num_timesteps = 10_010 # very greedy (epsilon has decayed to EPSILON_MIN)
print("Greedy actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")
print()

num_timesteps = 0 # completely random (epsilon = 1.0)
print("Random actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")

In [None]:
#@title Exercise 18 solution

# Now make a function that takes an epsilon-greedy action

def select_epsilon_greedy_action(key, q_values, num_timesteps):

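    epsilon = get_epsilon(num_timesteps)

    # Split the key so that the explore decision and the random action
    # use independent random numbers
    explore_key, action_key = jax.random.split(key)

    should_explore = jax.random.uniform(explore_key, (1,))[0] < epsilon

    num_actions = len(q_values)

    action = jax.lax.select(
        should_explore,
        select_random_action(action_key, num_actions), 
        select_greedy_action(q_values)
    )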

    return action

# TEST
dummy_q_values = jnp.array([0,1], jnp.float32)
num_timesteps = 10_010 # very greedy (epsilon has decayed to EPSILON_MIN)
print("Greedy actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")
print()

num_timesteps = 0 # completely random (epsilon = 1.0)
print("Random actions:", end=" ")
for i in range(10):
    print(select_epsilon_greedy_action(next(rng), dummy_q_values, num_timesteps), end=" ")
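As a sanity check on epsilon-greedy selection, here is a self-contained sketch (hypothetical helper `epsilon_greedy`, with a fixed `epsilon = 0.5` in place of the `get_epsilon` schedule) that samples 10,000 actions in parallel with `jax.vmap`. It splits the key so that the explore decision and the random action draw on independent random numbers; reusing one key for both would correlate them. With 2 actions, the greedy action should be picked with probability $1 - \epsilon + \epsilon/2 = 0.75$.

```python
import jax
import jax.numpy as jnp

def epsilon_greedy(key, q_values, epsilon):
    # Independent keys for "explore or not?" and "which random action?".
    explore_key, action_key = jax.random.split(key)
    should_explore = jax.random.uniform(explore_key) < epsilon
    random_action = jax.random.randint(action_key, (), 0, q_values.shape[0])
    greedy_action = jnp.argmax(q_values)
    return jnp.where(should_explore, random_action, greedy_action)

q_values = jnp.array([0.0, 1.0])  # action 1 is the greedy action
epsilon = 0.5
keys = jax.random.split(jax.random.PRNGKey(0), 10_000)

# Map over the keys only; q_values and epsilon are shared by every sample.
actions = jax.vmap(epsilon_greedy, in_axes=(0, None, None))(keys, q_values, epsilon)

# Greedy when not exploring (prob 1 - eps), or when exploring and the random
# draw happens to pick it (prob eps / num_actions).
expected = 1 - epsilon + epsilon / q_values.shape[0]
print("Fraction of greedy actions:", jnp.mean(actions == 1), "expected:", expected)
```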

### Q-Network
Next, we use Haiku to make a simple neural network to approximate the Q-function.

In [None]:
def build_network(num_actions: int, layers=[10, 10]) -> hk.Transformed:
  """Factory for a simple MLP network for approximating Q-values."""

  def q_network(obs):
    network = hk.Sequential(
        [hk.Flatten(),
         hk.nets.MLP(layers + [num_actions])])
    return network(obs)

  return hk.without_apply_rng(hk.transform(q_network))

In [None]:
# Initialise Q-network
Q_NETWORK = build_network(num_actions=2)
dummy_obs = jnp.zeros((1,4), jnp.float32)
Q_NETWORK_WEIGHTS = Q_NETWORK.init(next(rng), dummy_obs)
print("Q-Learning params:", Q_NETWORK_WEIGHTS.keys())

### DQN action selection

We now have everything we need to make the `dqn_select_action` function. We will use the `actor_state` to store a counter that keeps track of the current number of timesteps.

In [None]:
# Actor state stores the current number of timesteps
ActorState = collections.namedtuple("ActorState", ["count"])

def dqn_select_action(key, params, actor_state, obs, evaluation=False):
    obs = jnp.expand_dims(obs, axis=0) # add dummy batch dim
    q_values = Q_NETWORK.apply(params.online, obs)[0] # remove batch dim

    action = select_epsilon_greedy_action(key, q_values, actor_state.count)
    greedy_action = select_greedy_action(q_values)

    action = jax.lax.select(
        evaluation,
        greedy_action,
        action
    )

    next_actor_state = ActorState(actor_state.count + 1) # increment timestep counter

    return action, next_actor_state

### Transition Replay Buffer
For the DQN algorithm we will need an agent memory that stores entire transitions: `obs`, `action`, `reward`, `next_obs`, `done`. When we sample the memory, samples should be chosen from the memory randomly.

In [None]:
class TransitionMemory(object):
  """A simple Python replay buffer."""

  def __init__(self, max_size=5000, batch_size=256):
    self.batch_size = batch_size
    self.buffer = collections.deque(maxlen=max_size)

  def push(self, transition):
    self.buffer.append(
        (transition.obs, transition.action, transition.reward,
         transition.next_obs, transition.done)
    )

  
  def is_ready(self):
    return self.batch_size <= len(self.buffer)

  def sample(self):
    random_replay_sample = random.sample(self.buffer, self.batch_size)
    obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*random_replay_sample)

    return Transition(
        np.stack(obs_batch).astype("float32"), 
        np.asarray(action_batch).astype("int32"), 
        np.asarray(reward_batch).astype("float32"), 
        np.stack(next_obs_batch).astype("float32"), 
        np.asarray(done_batch).astype("float32")
    )

### Bellman Error

In [None]:
# Squared error
def compute_squared_error(pred, target):
    squared_error = jnp.square(pred - target)
    return squared_error

In [None]:
# Bellman target
def compute_bellman_target(reward, done, target_value):
    return reward + (1.0 - done) * target_value

### Q-learning loss

In [None]:
def q_learning_loss(q_values, action, reward, done, target_q_values):
    q_value = q_values[action]
    target_q_value = jnp.max(target_q_values)
    bellman_target = compute_bellman_target(reward, done, target_q_value)
    squared_error = compute_squared_error(q_value, bellman_target)
    return squared_error

In [None]:
def batched_q_learning_loss(online_params, obs, actions, rewards, next_obs, dones, target_params):
    q_values = Q_NETWORK.apply(online_params, obs)
    target_q_values = Q_NETWORK.apply(target_params, next_obs)
    squared_error = jax.vmap(q_learning_loss)(q_values, actions, rewards, dones, target_q_values)
    mean_squared_error = jnp.mean(squared_error)
    return mean_squared_error
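The loss above regresses the online Q-value of the chosen action towards the Bellman target $y = r + (1 - d)\,\max_{a'} Q_{\text{target}}(s', a')$, where $d$ is the `done` flag, and averages the squared error $(Q_{\text{online}}(s, a) - y)^2$ over the batch. Note that this target has no discount factor (equivalently $\gamma = 1$); many DQN implementations multiply the max term by a discount $\gamma < 1$. The sketch below works through a hypothetical batch of two transitions with hand-picked Q-values instead of network outputs, using a hypothetical `example_q_learning_loss` that mirrors `q_learning_loss`:

```python
import jax
import jax.numpy as jnp

# Hypothetical batch of 2 transitions with hand-picked Q-values.
q_values = jnp.array([[1.0, 2.0],           # Q_online(s, .) for transition 0
                      [0.5, 0.0]])          # Q_online(s, .) for transition 1
actions = jnp.array([1, 0])
rewards = jnp.array([1.0, 1.0])
dones = jnp.array([0.0, 1.0])               # transition 1 ends the episode
target_q_values = jnp.array([[3.0, 2.5],    # Q_target(s', .) for transition 0
                             [4.0, 1.0]])   # ignored for transition 1 since done = 1

def example_q_learning_loss(q_values, action, reward, done, target_q_values):
    q_value = q_values[action]  # Q_online(s, a)
    bellman_target = reward + (1.0 - done) * jnp.max(target_q_values)
    return jnp.square(q_value - bellman_target)

squared_errors = jax.vmap(example_q_learning_loss)(
    q_values, actions, rewards, dones, target_q_values)
print(squared_errors)            # squared errors 4.0 and 0.25
print(jnp.mean(squared_errors))  # 2.125
```

Transition 0 has target $1 + 3 = 4$ and prediction $2$; transition 1 is terminal, so its target is just the reward $1$ and its prediction is $0.5$.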

In [None]:
LearnerState = collections.namedtuple("LearnerState", ["count", "optim_state"])

# Initialise Q-network optimizer
DQN_OPTIMIZER = optax.adam(1e-3) # learning rate = 0.001
DQN_optim_state = DQN_OPTIMIZER.init(Q_NETWORK_WEIGHTS)

### Target network update

In [None]:
DQNParams = collections.namedtuple("DQNParams", ["online", "target"])

TARGET_UPDATE_RATE = 50

def update_target(params, trainer_steps):

    target_params = jax.lax.cond(
        (trainer_steps % TARGET_UPDATE_RATE) == 0,
        lambda x: x.online,
        lambda x: x.target,
        params
    )

    return DQNParams(params.online, target_params)

# TEST
update_target_jit = jax.jit(update_target)
DQN_params = DQNParams(Q_NETWORK_WEIGHTS, Q_NETWORK_WEIGHTS)
print("Updated target params:", update_target_jit(DQN_params, 0).target.keys())


### DQN learn function

Next, let's implement the DQN learner step function.

In [None]:
def dqn_learner_step(rng, params, learner_state, memory):
    grad_loss = jax.grad(batched_q_learning_loss)(params.online, memory.obs, 
                                            memory.action, memory.reward, 
                                            memory.next_obs, memory.done,
                                            params.target)

    updates, opt_state = DQN_OPTIMIZER.update(grad_loss, learner_state.optim_state)

    online_params = optax.apply_updates(params.online, updates)

    # Maybe update target network params
    params = DQNParams(online_params, params.target)
    params = update_target(params, learner_state.count)

    # Increment learner step counter
    learner_state = LearnerState(learner_state.count + 1, opt_state)

    return params, learner_state


In [None]:
# DQN params
dqn_params = DQNParams(Q_NETWORK_WEIGHTS, Q_NETWORK_WEIGHTS)

# Initialise actor
actor_state = ActorState(0)
dqn_select_action_jit = jax.jit(dqn_select_action)

# Initialise learner
learner_state = LearnerState(0, DQN_optim_state)
dqn_learner_step_jit = jax.jit(dqn_learner_step)

# Initialise memory
memory = TransitionMemory(5000, 256) # store 5000 transitions

# Run environment loop
episode_returns, evaluator_returns = run_environment_loop(rng, env, dqn_params, 
                      dqn_select_action_jit, actor_state, 
                      dqn_learner_step_jit, learner_state, memory,
                      num_episodes=5_000,
                      learner_steps_per_episode=4)

plt.plot(evaluator_returns)
plt.show()

## Additional Resources

## Conclusion
**Summary:**

[Summary of the main points/takeaways from the prac.]

**Next Steps:** 

[Next steps for people who have completed the prac, like optional reading (e.g. blogs, papers, courses, youtube videos). This could also link to other pracs.]

**Appendix:** 

[Anything (probably math heavy stuff) we don't have space for in the main practical sections.]

**References:** 

[References for any content used in the notebook.]

For other practicals from the Deep Learning Indaba, please visit [here](https://github.com/deep-learning-indaba/indaba-pracs-2022).

Math foundations:


## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
#@title Generate Feedback Form. (Run Cell)
from IPython.display import HTML
HTML(
"""
<iframe 
	src="https://forms.gle/bvLLPX74LMGrFefo9"
  width="80%" 
	height="1200px" >
	Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />