In [1]:
import jax
import jax.numpy as jnp
import jaxlib
import optax
import haiku as hk
import gym

  PyTreeDef = type(jax.tree_structure(None))


In [2]:
# Create the LunarLander environment and print the size of its discrete action space.
# NOTE(review): `new_step_api=True` presumably selects the gym step API that reports
# termination and truncation as separate values — confirm against the installed gym version.
env = gym.make('LunarLander-v2', new_step_api=True)
print(env.action_space.n)

4


In [3]:
# Number of discrete actions; used as the width of the Q-network's output layer.
NUM_ACTIONS = env.action_space.n


## What we need

- [x] Environment 
- [ ] Memory Buffer
- [ ] DQN model
- [ ] loss function
- [ ] Training Loop


In [4]:
class TrainConfig:
    """Hyper-parameters shared by the replay buffer, the loss and the training loop."""

    # Replay buffer capacity and the number of experiences sampled per minibatch.
    MEMORY_SIZE = 10_000
    BATCH_SIZE = 32
    # A gradient step is considered on every N-th environment step of an episode.
    UPDATE_PARAMS_EVERY_N_STEPS = 4
    # Discount factor applied to the bootstrapped next-state value in the loss.
    GAMMA = 0.995
    # Step size of the soft update of the target-network parameters.
    TAU = 1e-3
    # Epsilon-greedy schedule: epsilon is multiplied by E_DECAY and floored at E_MIN.
    E_MIN = 0.01
    E_DECAY = 0.995
    # Number of training episodes and the cap on steps within one episode.
    N_EPISODES = 1_000
    MAX_N_STEPS_PER_EPISODE = 1_000
    

In [5]:
import dataclasses
from typing import NamedTuple
from collections import deque


@dataclasses.dataclass
class Experience:
    """One environment transition as stored in the replay buffer."""

    # NOTE(review): the observations stored here come straight from the gym
    # environment and appear to be NumPy arrays rather than jnp arrays.
    state: jnp.ndarray  # observation the action was chosen for
    action: int  # action passed to `env.step`
    reward: float  # reward returned by `env.step`
    next_state: jnp.ndarray  # observation returned by `env.step`
    done: bool  # third value returned by `env.step`


# Replay buffer: a bounded deque, so appending beyond MEMORY_SIZE drops the oldest entry.
memory = deque(maxlen=TrainConfig.MEMORY_SIZE)

## What we need

- [x] Environment 
- [x] Memory Buffer
- [ ] DQN model
- [ ] loss function
- [ ] Training Loop


In [6]:
from typing import NamedTuple
# @dataclasses.dataclass
class TrainingState(NamedTuple):
    """Bundle of everything the `update` step consumes and returns."""

    params: hk.Params  # live Q-network parameters, changed by the optimiser
    target_params: hk.Params  # slowly updated copy used to compute the TD targets
    eval_params: hk.Params  # moving average of `params` maintained in `update`
    opt_state: optax.OptState  # optimiser state belonging to `params`

# @dataclasses.dataclass
class Batch(NamedTuple):
    """A minibatch of transitions: one array per `Experience` field.

    Every field is built with `jnp.array` in `get_random_batch`, so all of
    them are arrays with the batch along the leading axis.
    """

    states: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_states: jnp.ndarray
    dones: jnp.ndarray



def network_fn(x: jnp.ndarray) -> jnp.ndarray:
    """Q-network forward pass.

    Two hidden linear layers of width 64, each followed by ReLU, and a linear
    output layer with one value per action (`NUM_ACTIONS`).
    """
    layers = [
        hk.Linear(64),
        jax.nn.relu,
        hk.Linear(64),
        jax.nn.relu,
        hk.Linear(NUM_ACTIONS),
    ]
    q_network = hk.Sequential(layers)
    return q_network(x)

# `network` and `target_network` wrap the same `network_fn`; the loss applies them
# with different parameter sets (`params` and `target_params` respectively).
network = hk.without_apply_rng(hk.transform(network_fn))
target_network = hk.without_apply_rng(hk.transform(network_fn))
optimiser = optax.adam(1e-3)
# The network parameters and optimiser state are initialised further down, once a
# sample batch is available to provide the input shapes.

In [7]:
import jax.random as jrandom
import random
import numpy as np
# NOTE(review): `keygen` is not referenced anywhere else in the visible notebook.
keygen = jrandom.PRNGKey(0)

def get_random_batch(memory, batch_size: int = TrainConfig.BATCH_SIZE):
    """Sample a minibatch of experiences uniformly, without replacement.

    Args:
        memory: Sequence of `Experience` objects (the replay buffer).
        batch_size: Number of experiences to draw. Defaults to
            `TrainConfig.BATCH_SIZE`, which was previously hard-coded.

    Returns:
        A `Batch` whose fields are arrays with `batch_size` entries along the
        leading axis, one entry per sampled experience.

    Raises:
        ValueError: If `memory` holds fewer than `batch_size` experiences
            (raised by `random.sample`).
    """
    # Named `experiences` rather than `batch` so the notebook-level `batch`
    # variable is not shadowed inside the function.
    experiences = random.sample(memory, k=batch_size)
    return Batch(
        states=jnp.array([e.state for e in experiences]),
        actions=jnp.array([e.action for e in experiences]),
        rewards=jnp.array([e.reward for e in experiences]),
        next_states=jnp.array([e.next_state for e in experiences]),
        dones=jnp.array([e.done for e in experiences]),
    )

# Collect a few random-policy transitions. They are only used to obtain a
# sample batch whose shapes initialise the network.
small_memory = deque(maxlen=1000)

state = env.reset()
for _ in range(200):
    action = env.action_space.sample()
    next_state, reward, is_done, *_ = env.step(action)
    experience = Experience(state, action, reward, next_state, is_done)
    small_memory.append(experience)
    # Advance to the observation just received so that each stored transition
    # pairs a state with its real successor; start a new episode when one ends.
    state = env.reset() if is_done else next_state

batch = get_random_batch(small_memory)

batch.states.shape



(32, 8)

In [8]:

# Initialise the network parameters from a fixed seed, using the sample batch of
# states to supply the input shape, then build the matching optimiser state.
initial_params = network.init(
    jax.random.PRNGKey(seed=0), batch.states)
initial_opt_state = optimiser.init(initial_params)
# The live, target and evaluation parameters all start from the same values.
train_state = TrainingState(initial_params, initial_params, initial_params, initial_opt_state)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


## What we need

- [x] Environment 
- [x] Memory Buffer
- [x] DQN model
- [ ] loss function
- [ ] Training Loop


In [9]:
# Scratch walk-through of the pieces that go into the `loss` function defined
# further down, evaluated on the sample batch.
params = train_state.params
# Despite the name, these are the Q-values of the *next* states; this scratch
# version uses the live parameters rather than the target parameters.
state_actions_values = network.apply(params, batch.next_states)
max_state_actions_values = jnp.max(state_actions_values, axis=1)
# Transitions flagged as done contribute no bootstrapped future value.
targets = batch.rewards + TrainConfig.GAMMA * jnp.where(batch.dones, 0.0, max_state_actions_values)

# Q-values of the current states, one row per transition and one column per action.
q_values = network.apply(params, batch.states)



In [10]:
q_values.shape

(32, 4)

In [11]:

# For each row i, pick the Q-value of the action taken in that transition:
# q_values[i, batch.actions[i]].
q_value_for_action_taken = q_values[jnp.arange(q_values.shape[0]), batch.actions]


In [12]:
# Check the indexing on the first transition: the selected value must equal
# the entry of the first Q-value row at the index of the action taken.
print(q_values[0])
print(batch.actions[0])
print(q_value_for_action_taken[0])
assert q_value_for_action_taken[0] == q_values[0][batch.actions[0]]

[ 0.3777553   0.03897272  0.13142307 -0.01228983]
0
0.3777553


In [13]:
def loss(params, target_params, batch):
    """Mean squared temporal-difference error of the online network on `batch`.

    Args:
        params: Parameters of the online network (`network`).
        target_params: Parameters of the target network (`target_network`).
        batch: A `Batch` of transitions.

    Returns:
        Scalar mean of the squared difference between the online network's
        Q-value for each taken action and the corresponding target.
    """
    # Target: reward plus the discounted best next-state value according to the
    # target network; transitions flagged as done get no future value.
    next_q = target_network.apply(target_params, batch.next_states)
    best_next_q = jnp.max(next_q, axis=1)
    future_value = jnp.where(batch.dones, 0.0, best_next_q)
    td_target = batch.rewards + TrainConfig.GAMMA * future_value

    # Prediction: the online network's Q-value for the action actually taken.
    all_q = network.apply(params, batch.states)
    row_ids = jnp.arange(all_q.shape[0])
    predicted_q = all_q[row_ids, batch.actions]

    td_error = predicted_q - td_target
    return jnp.mean(td_error ** 2)

In [14]:
@jax.jit
def update(train_state: TrainingState, batch: Batch) -> TrainingState:
    """Take one optimiser step on `batch` and refresh the slow-moving parameter copies.

    Args:
        train_state: Current live/target/eval parameters and optimiser state.
        batch: Minibatch of transitions passed to `loss`.

    Returns:
        A new `TrainingState` holding the updated values.
    """
    # Gradient of the loss with respect to the live parameters (first argument).
    loss_grad_fn = jax.grad(loss)
    gradients = loss_grad_fn(train_state.params, train_state.target_params, batch)
    param_updates, new_opt_state = optimiser.update(gradients, train_state.opt_state)
    new_params = optax.apply_updates(train_state.params, param_updates)

    # Soft update of the target network:
    # target <- TAU * new_params + (1 - TAU) * target
    new_target_params = optax.incremental_update(
        new_params, train_state.target_params, TrainConfig.TAU
    )

    # Moving average of the live parameters, kept separately from the target
    # copy and intended for evaluation (cf. https://doi.org/10.1137/0330046).
    new_eval_params = optax.incremental_update(
        new_params, train_state.eval_params, step_size=0.001
    )

    return TrainingState(
        params=new_params,
        target_params=new_target_params,
        eval_params=new_eval_params,
        opt_state=new_opt_state,
    )

In [15]:
# Sanity check: apply the network to a single observation and to a stacked
# batch of two identical observations (only the last result is displayed).
state = env.reset()
batch_state = jnp.array([state, state])
network.apply(train_state.params, state)
network.apply(train_state.params, batch_state)

DeviceArray([[0.30043876, 0.21810007, 0.20305528, 0.05785726],
             [0.30043876, 0.21810007, 0.20305528, 0.05785726]],            dtype=float32)

In [16]:

def update_epsilon(epsilon, train_config: TrainConfig):
    """Apply one multiplicative decay step to epsilon, floored at `E_MIN`."""
    decayed = train_config.E_DECAY * epsilon
    floor = train_config.E_MIN
    return decayed if decayed > floor else floor

def exploit_or_explore(q_value: jnp.ndarray, epsilon: float = 0.1) -> int:
    """Choose an action with an epsilon-greedy policy.

    Args:
        q_value: Q-values of the current state, one entry per action.
        epsilon: Probability of taking a random action instead of the greedy one.

    Returns:
        The chosen action as a plain Python int. Both branches are converted,
        so the result matches the declared return type instead of being a
        0-d NumPy array or a NumPy integer.
    """
    if random.random() < epsilon:
        # Explore: sample from the global environment's action space.
        return int(env.action_space.sample())
    # Exploit: index of the largest Q-value.
    return int(jnp.argmax(q_value))

def is_update_params(n_steps_taken: int, train_config: TrainConfig) -> bool:
    """Tell whether a parameter update is due at the given zero-based step index.

    True when `n_steps_taken + 1` is a multiple of
    `train_config.UPDATE_PARAMS_EVERY_N_STEPS`.
    """
    period = train_config.UPDATE_PARAMS_EVERY_N_STEPS
    steps_completed = n_steps_taken + 1
    return steps_completed % period == 0


In [17]:
# Snapshots of the live parameters, keyed by a descriptive name; used later to
# render one video per snapshot.
params_at_different_training_steps = {}
total_reward_history = []
moving_average_window_size = 100
# Average return over the window at which the environment counts as solved.
solved_mean_reward = 200.0
# A parameter snapshot is stored every this many episodes.
snapshot_every_n_episodes = 100
epsilon = 1.0
train_config = TrainConfig()
for episode in range(train_config.N_EPISODES):
    state = env.reset()
    total_reward = 0.0

    for step in range(train_config.MAX_N_STEPS_PER_EPISODE):
        q_value = network.apply(train_state.params, state)
        action = exploit_or_explore(q_value=q_value, epsilon=epsilon)

        next_state, reward, is_done, *_ = env.step(action)
        experience = Experience(state, action, reward, next_state, is_done)
        memory.append(experience)

        # Learning only starts once the replay buffer has been filled completely;
        # until then the loop just collects experience.
        buffer_is_full = len(memory) >= train_config.MEMORY_SIZE
        if buffer_is_full and is_update_params(step, train_config=train_config):
            batch = get_random_batch(memory)
            train_state = update(train_state, batch)

        state = next_state
        total_reward += reward
        if is_done:
            break

    total_reward_history.append(total_reward)
    mean_total_reward_in_window = np.mean(total_reward_history[-moving_average_window_size:])
    epsilon = update_epsilon(epsilon, train_config)

    # The progress line overwrites itself via "\r"; once per window it is printed
    # with a newline so that it stays in the output.
    progress = f"\rEpisode {episode+1} | Total point average of the last {moving_average_window_size} episodes: {mean_total_reward_in_window:.2f}"
    print(progress, end="")
    if (episode+1) % moving_average_window_size == 0:
        print(progress)

    if (episode+1) % snapshot_every_n_episodes == 0:
        network_params_name = f"params_episode_{episode + 1}"
        params_at_different_training_steps[network_params_name] = train_state.params

    # We will consider that the environment is solved if we get an
    # average of 200 points in the last 100 episodes. The window must be full
    # so that a few lucky early episodes cannot trigger this.
    window_is_full = len(total_reward_history) >= moving_average_window_size
    if window_is_full and mean_total_reward_in_window >= solved_mean_reward:
        print(f"\n\nEnvironment solved in {episode+1} episodes!")
        network_params_name = "params_episode_final"
        params_at_different_training_steps[network_params_name] = train_state.params
        # Stop training; without this the loop kept running (and re-reporting
        # "solved") until N_EPISODES was exhausted.
        break

Episode 100 | Total point average of the last 100 episodes: -170.87
Episode 200 | Total point average of the last 100 episodes: -147.81
Episode 300 | Total point average of the last 100 episodes: -39.286
Episode 400 | Total point average of the last 100 episodes: 40.056
Episode 500 | Total point average of the last 100 episodes: 94.68
Episode 600 | Total point average of the last 100 episodes: 144.83
Episode 664 | Total point average of the last 100 episodes: 178.36

In [1]:
import base64
import imageio
import IPython

def create_video(filename, env, train_state, fps=30, max_steps=300):
    """Record a greedy rollout of the Q-network to a video file.

    The output always contains `max_steps + 1` frames: the frame rendered
    after reset, then one frame per step. If the episode ends early, the last
    rendered frame is repeated so that every video has the same length.

    Args:
        filename: Output path handed to `imageio.get_writer`.
        env: Environment to roll out in; must support `render(mode="rgb_array")`.
        train_state: Only its `params` field is used, to evaluate `network`.
        fps: Frames per second of the written video.
        max_steps: Number of frames written after the initial one
            (previously hard-coded to 300).
    """
    with imageio.get_writer(filename, fps=fps) as video:
        state = env.reset()
        frame = env.render(mode="rgb_array")
        video.append_data(frame)

        done = False
        for _step in range(max_steps):
            if not done:
                q_values = network.apply(train_state.params, state)
                # Greedy action as a plain int for `env.step`.
                action = int(jnp.argmax(q_values))
                state, _reward, done, *_rest = env.step(action)
                frame = env.render(mode="rgb_array")
            # Once the episode has ended, `frame` is simply repeated as padding.
            video.append_data(frame)



In [2]:
import os
# Render one GIF per stored parameter snapshot into ./videos.
os.makedirs('./videos', exist_ok=True)
for episode_name, episode_params in params_at_different_training_steps.items():
    filename = f"./videos/lunar_{episode_name}.gif"
    # create_video only reads `.params`; the remaining fields are filled in just
    # to build a complete TrainingState.
    episode_train_state = TrainingState(episode_params, episode_params, episode_params, initial_opt_state)
    create_video(filename, env, episode_train_state)

NameError: name 'params_at_different_training_steps' is not defined


| | | |
|:-------------------------:|:-------------------------:|:-------------------------:|
|<img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="./videos/lunar_params_episode_100.gif">  100 Episodes |  <img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="videos/lunar_params_episode_200.gif"> 200 Episodes|<img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="videos/lunar_params_episode_300.gif"> 300 Episodes |
|<img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="./videos/lunar_params_episode_400.gif">  400 Episodes |  <img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="videos/lunar_params_episode_500.gif"> 500 Episodes|<img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="videos/lunar_params_episode_600.gif"> 600 Episodes |
|<img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="./videos/lunar_params_episode_700.gif">  700 Episodes |  <img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="videos/lunar_params_episode_800.gif"> 800 Episodes|<img width="1604" alt="screen shot 2017-08-07 at 12 18 15 pm" src="videos/lunar_params_episode_final.gif"> Expert Level Episodes |

In [60]:
import tensorflow as tf
import tree
from jax.experimental import jax2tf
import sonnet as snt



def create_variable(path, value):
  """Create a `tf.Variable` for one parameter leaf.

  The variable name is the leaf's path joined with '/', with every '~'
  replaced by '_'.
  """
  joined_path = '/'.join(str(part) for part in path)
  variable_name = joined_path.replace('~', '_')
  return tf.Variable(value, name=variable_name)
# Length of one observation vector (the sample batch of states has shape (32, 8)).
N_FEATURES_PER_STATE = 8
# Input shape given to jax2tf for the state argument: a named dimension "b"
# followed by the fixed feature count.
# NOTE(review): presumably "b" keeps the batch dimension symbolic in the
# converted function — confirm against the jax2tf shape-polymorphism docs.
polymorphic_state_shape = jax2tf.shape_poly.PolyShape(
  "b", N_FEATURES_PER_STATE
)

class JaxModule(snt.Module):
  """Sonnet module exposing a JAX apply function and its parameters to TensorFlow.

  The parameters are copied into `tf.Variable`s (one per leaf, see
  `create_variable`) and the apply function is converted with `jax2tf`.
  """

  def __init__(self, params, apply_fn, name=None):
    super().__init__(name=name)
    self._params = tree.map_structure_with_path(create_variable, params)
    # No polymorphic shape is declared for the parameters (None); the state
    # input uses `polymorphic_state_shape` (equivalent string spec: "b, 8").
    converted_apply = jax2tf.convert(
        lambda p, x: apply_fn(p, x),
        polymorphic_shapes=[None, polymorphic_state_shape],
    )
    self._apply = tf.autograph.experimental.do_not_convert(converted_apply)

  def __call__(self, inputs):
    """Apply the converted function to `inputs` using the stored variables."""
    return self._apply(self._params, inputs)


# network = hk.without_apply_rng(hk.transform(network_fn))
# target_network = hk.without_apply_rng(hk.transform(network_fn))
# optimiser = optax.adam(1e-3)
# initial_params = network.init(
#     jax.random.PRNGKey(seed=0), batch.states)
# Wrap the trained live parameters and the Haiku apply function, then list the
# names of the TensorFlow variables that were created for them.
net = JaxModule(train_state.params, network.apply)
[v.name for v in net.trainable_variables]



  import distutils as _distutils
2022-08-22 12:15:41.382768: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
  import imp
  'nearest': pil_image.NEAREST,
  'bilinear': pil_image.BILINEAR,
  'bicubic': pil_image.BICUBIC,
  'hamming': pil_image.HAMMING,
  'box': pil_image.BOX,
  'lanczos': pil_image.LANCZOS,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  BoolLike = Union[bool, np.bool, TensorLike]
2022-08-22 12:15:43.342022: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-08-22 12:15:43.342059: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)


['jax_module/linear/b:0',
 'jax_module/linear/w:0',
 'jax_module/linear_1/b:0',
 'jax_module/linear_1/w:0',
 'jax_module/linear_2/b:0',
 'jax_module/linear_2/w:0']

In [61]:
# Q-values of the freshly initialised (untrained) parameters for the current `state`.
network.apply(initial_params, state)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


DeviceArray([ 0.46667343, -0.03591029, -0.29230744, -0.16951452], dtype=float32)

In [62]:

# The feature dimension comes from N_FEATURES_PER_STATE so the exported
# signature stays in sync with the shape used for the jax2tf conversion.
@tf.function(autograph=False, input_signature=[tf.TensorSpec([None, N_FEATURES_PER_STATE])])
def forward(x):
  """Run the converted Q-network on a batch of states with an unspecified batch size."""
  return net(x)

# Plain tf.Module used as the export container: it carries the traced forward
# function and a list of the module's variables.
to_save = tf.Module()
to_save.forward = forward
to_save.params = list(net.variables)

In [63]:
# Export the container module (forward function and variables) as a SavedModel directory.
tf.saved_model.save(to_save, "./lunar_lander_model")

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


In [64]:
# Reload the SavedModel and call it two ways: through the restored `forward`
# attribute and through the exported 'serving_default' signature, using
# different batch sizes (3 and 2).
loaded = tf.saved_model.load("./lunar_lander_model/")
preds = loaded.forward(tf.ones([3, 8]))
preds_with_serving_default = loaded.signatures['serving_default'](x=tf.ones([2,8]))

In [65]:
state

array([-0.5280086 , -0.14223678, -0.32590032, -0.2560407 ,  0.36866826,
       -3.6965518 ,  0.        ,  1.        ], dtype=float32)

In [66]:
import tensorflow as tf
# NOTE(review): this repeats the load already done in an earlier cell; it looks
# like a check that the model can be loaded with only TensorFlow imported.
loaded = tf.saved_model.load("./lunar_lander_model/")

<a name="6"></a>
## 6 - Deep Q-Learning

In cases where both the state and action space are discrete we can estimate the action-value function iteratively by using the Bellman equation:

$$
Q_{i+1}(s,a) = R + \gamma \max_{a'}Q_i(s',a')
$$

This iterative method converges to the optimal action-value function $Q^*(s,a)$ as $i\to\infty$. This means that the agent just needs to gradually explore the state-action space and keep updating the estimate of $Q(s,a)$ until it converges to the optimal action-value function $Q^*(s,a)$. However, in cases where the state space is continuous it becomes practically impossible to explore the entire state-action space. Consequently, this also makes it practically impossible to gradually estimate $Q(s,a)$ until it converges to $Q^*(s,a)$.

In Deep $Q$-Learning, we solve this problem by using a neural network to estimate the action-value function $Q(s,a)\approx Q^*(s,a)$. We call this neural network a $Q$-Network and it can be trained by adjusting its weights at each iteration to minimize the mean-squared error in the Bellman equation.

Unfortunately, using neural networks in reinforcement learning to estimate action-value functions has proven to be highly unstable. Luckily, there are a couple of techniques that can be employed to avoid instabilities. These techniques consist of using a ***Target Network*** and ***Experience Replay***. We will explore these two techniques in the following sections.

### 6.1 Target Network

We can train the $Q$-Network by adjusting its weights at each iteration to minimize the mean-squared error in the Bellman equation, where the target values are given by:

$$
y = R + \gamma \max_{a'}Q(s',a';w)
$$

where $w$ are the weights of the $Q$-Network. This means that we are adjusting the weights $w$ at each iteration to minimize the following error:

$$
\overbrace{\underbrace{R + \gamma \max_{a'}Q(s',a'; w)}_{\rm {y~target}} - Q(s,a;w)}^{\rm {Error}}
$$

Notice that this forms a problem because the $y$ target is changing on every iteration. Having a constantly moving target can lead to oscillations and instabilities. To avoid this, we can create
a separate neural network for generating the $y$ targets. We call this separate neural network the **target $\hat Q$-Network** and it will have the same architecture as the original $Q$-Network. By using the target $\hat Q$-Network, the above error becomes:

$$
\overbrace{\underbrace{R + \gamma \max_{a'}\hat{Q}(s',a'; w^-)}_{\rm {y~target}} - Q(s,a;w)}^{\rm {Error}}
$$

where $w^-$ and $w$ are the weights of the target $\hat Q$-Network and $Q$-Network, respectively.

In practice, we will use the following algorithm: every $C$ time steps we will use the $\hat Q$-Network to generate the $y$ targets and update the weights of the target $\hat Q$-Network using the weights of the $Q$-Network. We will update the weights $w^-$ of the target $\hat Q$-Network using a **soft update**. This means that we will update the weights $w^-$ using the following rule:
 
$$
w^-\leftarrow \tau w + (1 - \tau) w^-
$$

where $\tau\ll 1$. By using the soft update, we are ensuring that the target values, $y$, change slowly, which greatly improves the stability of our learning algorithm.