In [1]:
import jax
import jax.numpy as jnp
import jaxlib
import optax
import haiku as hk
import gym

  PyTreeDef = type(jax.tree_structure(None))


In [4]:
env = gym.make('LunarLander-v2', new_step_api=True)
print(env.action_space.n)

4


In [5]:
NUM_ACTIONS = env.action_space.n


## What we need

- [x] Environment 
- [] Memory Buffer
- [] DQN model
- [] loss function
- [] Training Loop


In [6]:
class TrainConfig:
    MEMORY_SIZE = 10000
    BATCH_SIZE = 64
    UPDATE_PARAMS_EVERY_N_STEPS = 4
    TAU = 0.001
    E_MIN = 0.01
    E_DECAY = 0.995
    N_EPISODES = 2000
    MAX_N_STEPS_PER_EPISODE = 1000
    

In [7]:
import dataclasses
from typing import NamedTuple
from collections import deque


@dataclasses.dataclass
class Experience:
    state: jnp.ndarray
    action: int
    reward: float
    next_state: jnp.ndarray
    done: bool


memory = deque(maxlen=TrainConfig.MEMORY_SIZE)

## What we need

- [x] Environment 
- [x] Memory Buffer
- [] DQN model
- [] loss function
- [] Training Loop


In [8]:
from typing import NamedTuple
# @dataclasses.dataclass
class TrainingState(NamedTuple):
    params: hk.Params
    target_params: hk.Params
    eval_params: hk.Params
    opt_state: optax.OptState

# @dataclasses.dataclass
class Batch(NamedTuple):
    states: jnp.ndarray
    actions: int
    rewards: float
    next_states: jnp.ndarray
    dones: bool



def network_fn(x: jnp.ndarray) -> jnp.ndarray:
    model = hk.Sequential(
        [
            hk.Linear(64),jax.nn.relu,
            hk.Linear(64), jax.nn.relu,
            hk.Linear(NUM_ACTIONS),
        ]

    )
    return model(x)

network = hk.without_apply_rng(hk.transform(network_fn))
target_network = hk.without_apply_rng(hk.transform(network_fn))
optimiser = optax.adam(1e-3)
# Initialise network and optimiser; note we draw an input to get shapes.

In [9]:
import jax.random as jrandom
import random
import numpy as np
keygen = jrandom.PRNGKey(0)

def get_random_batch(memory):
    batch = random.sample(memory, k=TrainConfig.BATCH_SIZE)
    return Batch(
        states=jnp.array([e.state for e in batch]),
        actions=jnp.array([e.action for e in batch]),
        rewards=jnp.array([e.reward for e in batch]),
        next_states=jnp.array([e.next_state for e in batch]),
        dones=jnp.array([e.done for e in batch]),
    )

small_memory = deque(maxlen=1000)

state = env.reset()
action = env.action_space.sample()
for _ in range(200):
    action = env.action_space.sample()
    next_state, reward, is_done, *_ = env.step(action)
    experience = Experience(state, action, reward, next_state, is_done)
    small_memory.append(experience)

batch = get_random_batch(small_memory)

batch.states.shape



(64, 8)

In [10]:

initial_params = network.init(
    jax.random.PRNGKey(seed=0), batch.states)
initial_opt_state = optimiser.init(initial_params)
train_state = TrainingState(initial_params, initial_params, initial_params, initial_opt_state)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


## What we need

- [x] Environment 
- [x] Memory Buffer
- [x] DQN model
- [] loss function
- [] Training Loop


In [11]:
# def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
params = train_state.params
state_actions_values = network.apply(params, batch.next_states)
max_state_actions_values = jnp.max(state_actions_values, axis=1)
targets = batch.rewards + jnp.where(batch.dones, 0.0, max_state_actions_values)

q_values = network.apply(params, batch.states)



In [12]:
q_values.shape

(64, 4)

In [13]:

q_value_for_action_taken = q_values[jnp.arange(q_values.shape[0]), batch.actions]


In [14]:
print(q_values[0])
print(batch.actions[0])
print(q_value_for_action_taken[0])
assert q_value_for_action_taken[0] == q_values[0][batch.actions[0]]

[0.2785838  0.16615024 0.14843863 0.02754669]
2
0.14843863


In [15]:
def loss(params, target_params, batch):
    q_values = network.apply(params, batch.states)
    q_values_pred = q_values[jnp.arange(q_values.shape[0]), batch.actions]

    q_values_next = target_network.apply(target_params, batch.next_states)
    q_values_next_max = jnp.max(q_values_next, axis=1)

    q_value_true = batch.rewards + jnp.where(batch.dones, 0.0, q_values_next_max)
    return jnp.mean((q_values_pred - q_value_true) ** 2)

In [16]:
@jax.jit
def update(train_state: TrainingState, batch: Batch) -> TrainingState:
    """Learning rule (stochastic gradient descent)."""
    grads = jax.grad(loss)(train_state.params, train_state.target_params, batch)
    updates, opt_state = optimiser.update(grads, train_state.opt_state)
    params = optax.apply_updates(train_state.params, updates)

    # Update target network.
    # params * TAU + (1 - TAU) * new_params
    # target_params = params * TrainConfig.TAU  + (1 - TrainConfig.TAU) * train_state.target_params
    target_params = optax.incremental_update(params, train_state.target_params, TrainConfig.TAU)
    
    # Compute avg_params, the exponential moving average of the "live" params.
    # We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
    eval_params = optax.incremental_update(
        params, train_state.eval_params, step_size=0.001)
    return TrainingState(params, target_params, eval_params, opt_state)

In [17]:
state =env.reset()
batch_state = jnp.array([state, state])
network.apply(train_state.params, state)
network.apply(train_state.params, batch_state)

DeviceArray([[0.24135946, 0.03655888, 0.07189564, 0.03146065],
             [0.24135946, 0.03655888, 0.07189564, 0.03146065]],            dtype=float32)

In [18]:

def update_epsilon(epsilon, train_config: TrainConfig):
    return max(train_config.E_MIN, train_config.E_DECAY*epsilon)

def exploit_or_explore(q_value: jnp.ndarray, epsilon: float = 0.1) -> int:
    """Exploit or explore according to epsilon-greedy policy."""
    if random.random() < epsilon:
        return env.action_space.sample()
    else:
        return np.array(jnp.argmax(q_value))

def is_update_params(n_steps_taken: int, train_config: TrainConfig) -> bool:
    """Update params every `update_params_every` steps."""
    return (n_steps_taken + 1) % train_config.UPDATE_PARAMS_EVERY_N_STEPS == 0


In [19]:
    
total_reward_history = []
moving_average_window_size = 100
epsilon = 1.0
train_config = TrainConfig()
for episode in range(train_config.N_EPISODES):
    state = env.reset()
    total_reward = 0.0

    for step in range(train_config.MAX_N_STEPS_PER_EPISODE):
        q_value = network.apply(train_state.params, state)
        action = exploit_or_explore(q_value=q_value, epsilon=epsilon)
        
        next_state, reward, is_done, *_ = env.step(action)
        experience = Experience(state, action, reward, next_state, is_done)
        memory.append(experience)
        if len(memory) < TrainConfig.MEMORY_SIZE:
            state = next_state
            total_reward += reward
            if is_done:
                break
            continue

        if is_update_params(step, train_config=train_config):
            batch =get_random_batch(memory)
            train_state = update(train_state, batch)

        state = next_state
        total_reward += reward
        if is_done:
            break

    total_reward_history.append(total_reward)
    mean_total_reward_in_window = np.mean(total_reward_history[-moving_average_window_size:])
    epsilon = update_epsilon(epsilon, train_config)


    print(f"\rEpisode {episode+1} | Total point average of the last {moving_average_window_size} episodes: {mean_total_reward_in_window:.2f}", end="")

    if (episode+1) % moving_average_window_size == 0:
        print(f"\rEpisode {episode+1} | Total point average of the last {moving_average_window_size} episodes: {mean_total_reward_in_window:.2f}")

    # We will consider that the environment is solved if we get an
    # average of 200 points in the last 100 episodes.
    if mean_total_reward_in_window >= 200.0:
        print(f"\n\nEnvironment solved in {episode+1} episodes!")
        # q_network.save('lunar_lander_model.h5')
        break

Episode 28 | Total point average of the last 100 episodes: -157.51

KeyboardInterrupt: 

In [20]:
q_values, action, np.asarray([action])[0]

(DeviceArray([[0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0.14843863, 0.02754669],
              [0.2785838 , 0.16615024, 0

In [21]:
# import matplotlib.pyplot as plt 
# import time
# done = False
# state = env.reset()
# frame = env.render(mode="rgb_array")
# for _ in range(100):    
#     q_values = network.apply(train_state.params, state)
#     action = jnp.argmax(q_values)
#     state, _, done, *_ = env.step(np.asarray([action])[0])
#     env.render()
#     if done:
#         env.close()

In [22]:
import base64

import imageio
import IPython

def create_video(filename, env, train_state, fps=30):
    with imageio.get_writer(filename, fps=fps) as video:
        done = False
        state = env.reset()
        frame = env.render(mode="rgb_array")
        video.append_data(frame)
        while not done:    
            q_values = network.apply(train_state.params, state)
            action = jnp.argmax(q_values)
            state, _, done, *_ = env.step(np.asarray([action])[0])
            frame = env.render(mode="rgb_array")
            video.append_data(frame)



In [24]:
from IPython.display import Video
filename = "./lunar_lander.mp4"
create_video(filename, env, train_state)
Video(filename)

See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(


In [71]:
import tensorflow as tf
import tree
from jax.experimental import jax2tf
import sonnet as snt


def network_fn(x: jnp.ndarray) -> jnp.ndarray:
    model = hk.Sequential(
        [
            hk.Linear(64),jax.nn.relu,
            hk.Linear(64), jax.nn.relu,
            hk.Linear(4),
        ]

    )
    return model(x)



def create_variable(path, value):
  name = '/'.join(map(str, path)).replace('~', '_')
  return tf.Variable(value, name=name)

polymorphic_state_shape = jax2tf.shape_poly.PolyShape(
  "None, 8"
)

class JaxModule(snt.Module):
  def __init__(self, params, apply_fn, name=None):
    super().__init__(name=name)
    self._params = tree.map_structure_with_path(create_variable, params)
    self._apply = jax2tf.convert(lambda p, x: apply_fn(p, x), polymorphic_shapes=[None, "b, 8"])
    self._apply = tf.autograph.experimental.do_not_convert(self._apply)

  def __call__(self, inputs):
    return self._apply(self._params, inputs)


network = hk.without_apply_rng(hk.transform(network_fn))
# network = hk.transform(network_fn)
target_network = hk.without_apply_rng(hk.transform(network_fn))
optimiser = optax.adam(1e-3)
initial_params = network.init(
    jax.random.PRNGKey(seed=0), batch.states)
net = JaxModule(initial_params, network.apply)
[v.name for v in net.trainable_variables]



  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


['jax_module/linear/b:0',
 'jax_module/linear/w:0',
 'jax_module/linear_1/b:0',
 'jax_module/linear_1/w:0',
 'jax_module/linear_2/b:0',
 'jax_module/linear_2/w:0']

In [72]:
network.apply(initial_params, state)

DeviceArray([ 0.36484796, -0.00536182,  0.11019011,  0.01452891], dtype=float32)

In [75]:

@tf.function(autograph=False, input_signature=[tf.TensorSpec([None, 8])])
def forward(x):
  return net(x)

to_save = tf.Module()
to_save.forward = forward
to_save.params = list(net.variables)

In [76]:
tf.saved_model.save(to_save, "./lunar_lander_model")

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


In [79]:
loaded = tf.saved_model.load("./lunar_lander_model/")
preds = loaded.forward(tf.ones([3, 8]))

In [31]:
state

array([ 0.0966568 ,  1.3465225 ,  0.3490748 , -0.35275942, -0.16664864,
       -0.20712912,  0.        ,  0.        ], dtype=float32)

In [1]:
import tensorflow as tf
loaded = tf.saved_model.load("./lunar_lander_model/")

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



In [3]:
tf.ones([2, 8], dtype=tf.float64)

<tf.Tensor: shape=(2, 8), dtype=float64, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.]])>

In [5]:
loaded.signatures['serving_default'](tf.ones([2, 8]))

: 

: 

In [44]:
network.variables

AttributeError: 'Transformed' object has no attribute 'variables'

In [25]:

state = env.reset()
q_values = network.apply(train_state.params, state)
action = jnp.argmax(q_values)
action, q_values

(DeviceArray(3, dtype=int32),
 DeviceArray([70.321625, 69.91499 , 65.94694 , 74.401695], dtype=float32))

In [257]:
min(get_random_batch(memory).rewards)

DeviceArray(-15.972583, dtype=float32)

In [237]:
batch = get_random_batch(memory)

grads = jax.grad(loss)(train_state.params, train_state.target_params, batch)
updates, opt_state = optimiser.update(grads, train_state.opt_state)
params = optax.apply_updates(train_state.params, updates)

# Update target network.
# params * TAU + (1 - TAU) * new_params
# target_params = params * TrainConfig.TAU  + (1 - TrainConfig.TAU) * train_state.target_params
target_params = optax.incremental_update(params, train_state.target_params, TrainConfig.TAU)

# Compute avg_params, the exponential moving average of the "live" params.
# We use this only for evaluation (cf. https://doi.org/10.1137/0330046).
eval_params = optax.incremental_update(
    params, train_state.eval_params, step_size=0.001)

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


In [206]:
params

{'linear': {'b': DeviceArray([ 0.00099999,  0.00099999, -0.00099999,  0.00099999,
                0.00099999,  0.00099999,  0.00099994,  0.00099999,
               -0.00099999, -0.00099999, -0.00099999,  0.00099999,
               -0.00099999, -0.00099999,  0.00099999,  0.00099997,
               -0.00099999, -0.00099999,  0.00099999, -0.00099999,
               -0.00099999,  0.00099999, -0.00099999,  0.00099999,
                0.00099999, -0.00099999,  0.00099999,  0.00099999,
                0.00099999,  0.00099999, -0.00099999, -0.00099999,
               -0.00099999, -0.00099999, -0.00099999,  0.00099999,
               -0.00099999,  0.        , -0.00099999,  0.00099999,
                0.00099999,  0.00099999, -0.00099999, -0.00099999,
               -0.00099999,  0.00099999, -0.00099999,  0.00099999,
                0.00099999, -0.00099999, -0.00099999,  0.00099999,
               -0.00099999,  0.00099999, -0.00099999,  0.00099999,
               -0.00099999,  0.00099999,  0.000

In [164]:
import numpy as np
np.array(jnp.array([0]))[0]
jnp.array([0])

DeviceArray(0, dtype=int32)