Dyna-Q


In [1]:
import random
import gymnasium as gym
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn

class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        return x

class DynamicsModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        x = nn.Dense(features=32)(x)
        x = nn.relu(x)
        x = nn.Dense(features=5)(x)
        return x

def epsilon_greedy_policy(q_values, epsilon):
    if random.random() < epsilon:
        return random.randint(0, 1)
    else:
        return int(jnp.argmax(q_values))

def dyna_q(env, q_params, q_opt_state, dynamics_params, dynamics_opt_state, episodes=500, planning_steps=5, epsilon=0.1, gamma=0.99, alpha=0.5):
    rng = jax.random.PRNGKey(0)

    # initialize q network
    q_model = QNetwork()
    q_values = lambda params, s: q_model.apply(params, s.reshape(1, -1))

    # initialize dynamic network
    dynamics_model = DynamicsModel()
    predict_dynamics = lambda params, s, a: dynamics_model.apply(params, jnp.hstack([s, jnp.array([a])]).reshape(1, -1))

    # initialize buffer
    transition_memory = []

    # for each episode
    for episode in range(episodes):
        s, _ = env.reset()
        done = False

        # for each step
        while not done:
            # choose an action a from current state s based on eps-greedy strategy using q values
            a = epsilon_greedy_policy(q_values(q_params, s), epsilon)
            # execute the action and observe the resuting reward and the next state
            s_next, r, done, _, _ = env.step(a)
            # update buffer
            transition_memory.append((s, a, r, s_next))

            # calculate the target q-value for the current state-action pair
            # [0] is batch
            q_s = q_values(q_params, s)
            target = q_s[0].copy()
            if done:
                target = target.at[a].set(r)
            else:
                target = target.at[a].set(r + gamma * jnp.max(q_values(q_params, s_next)))

            # update q-network using temporal difference error, which is the difference between the current estimate of the q-value and the target-q value
            loss_fn = lambda params: jnp.mean((q_model.apply(params, s.reshape(1, -1)) - target) ** 2)
            losses, grad = jax.value_and_grad(loss_fn)(q_params)
            #_, grad = grad_fn(q_params)
            updates, q_opt_state = q_optimizer.update(grad, q_opt_state)
            q_params = optax.apply_updates(q_params, updates)

            # update the dynamic network if the episode is not done
            if not done:
                inputs = jnp.hstack([s, jnp.array([a])])
                target_dynamics = jnp.hstack([r, s_next])
                print(target_dynamics.shape)
                print(dynamics_model.apply(dynamics_params, inputs.reshape(1, -1)))
                loss_fn_dynamics = lambda params: jnp.mean((dynamics_model.apply(params, inputs.reshape(1, -1)) - target_dynamics) ** 2)
                losses_dynamics, grad_dynamics = jax.value_and_grad(loss_fn_dynamics)(dynamics_params)
                #_,grad_dynamics= grad_fn_dynamics(dynamics_params)
                updates, dynamics_opt_state = dynamics_optimizer.update(grad_dynamics, dynamics_opt_state)
                dynamics_params = optax.apply_updates(dynamics_params, updates)

            # perform planning step using the transition memory
            for _ in range(planning_steps):
                # Sample a random transition from memory
                s_sample, a_sample, r_sample, s_next_sample = random.choice(transition_memory)

                # Use the environment model to predict the next state and reward
                r_pred, s_next_pred = predict_dynamics(dynamics_params, s_sample, a_sample).ravel()[0], predict_dynamics(dynamics_params, s_sample, a_sample).ravel()[1:]

                # Calculate the target Q-value for the sampled state-action pair
                q_s_sample = q_values(q_params, s_sample)
                target_sample = q_s_sample[0].copy()

                if s_next_sample is None:
                    target_sample = target_sample.at[a_sample].set(r_pred)
                else:
                    target_sample = target_sample.at[a_sample].set(r_pred + gamma * jnp.max(q_values(q_params, s_next_pred)))

                # Update the Q-value function using the sampled transition
                loss_fn_sample = lambda params: jnp.mean((q_model.apply(params, s_sample.reshape(1, -1)) - target_sample) ** 2)
                _, grad_sample = jax.value_and_grad(loss_fn_sample)(q_params)
                #grad_sample, _ = grad_fn_sample(q_params)
                updates, q_opt_state = q_optimizer.update(grad_sample, q_opt_state)
                q_params = optax.apply_updates(updates, q_params)


            s = s_next


  jax.tree_util.register_keypaths(


In [2]:

env = gym.make("CartPole-v1")
rng_q = jax.random.PRNGKey(0)
rng_dynamics = jax.random.PRNGKey(1)

q_model = QNetwork()
q_params = q_model.init(rng_q, jnp.ones((1, 4)))
q_optimizer = optax.adam(1e-4)
q_opt_state = q_optimizer.init(q_params)

dynamics_model = DynamicsModel()
dynamics_params = dynamics_model.init(rng_dynamics, jnp.ones((1, 5)))
dynamics_optimizer = optax.adam(1e-4)
dynamics_opt_state = dynamics_optimizer.init(dynamics_params)

dyna_q(env, q_params, q_opt_state, dynamics_params, dynamics_opt_state)
env.close()




No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(5,)
[[ 0.00411973 -0.00157021  0.01296842  0.01997603  0.00129607]]
(5,)
[[0.04237485 0.03214698 0.04422821 0.02989962 0.07982983]]
(5,)
[[0.09223083 0.05695689 0.09872299 0.05783403 0.16825108]]
(5,)
[[0.14141014 0.08316255 0.15441047 0.08640867 0.25840452]]
(5,)
[[0.19346826 0.10895582 0.21136934 0.11200678 0.35037637]]
(5,)
[[0.24827321 0.13466951 0.26957408 0.13468334 0.444739  ]]
(5,)
[[0.3054441  0.16079292 0.32884946 0.15498595 0.5421034 ]]
(5,)
[[ 0.02422397 -0.00266501  0.00220547  0.00268421  0.0008713 ]]
(5,)
[[0.07876007 0.022644   0.04799666 0.03213102 0.07967476]]
(5,)
[[ 0.43634105  0.04074448 -0.00952566  0.44173488  0.22512881]]
(5,)
[[0.07863867 0.02850013 0.05502581 0.02821335 0.09160579]]
(5,)
[[0.13234176 0.04644139 0.11458489 0.05608358 0.19091175]]
(5,)
[[0.18750185 0.06887501 0.16938804 0.08336797 0.2855445 ]]
(5,)
[[0.24506897 0.09291624 0.22480732 0.10814114 0.3823518 ]]
(5,)
[[0.30487266 0.11720402 0.28151172 0.13071866 0.4819183 ]]
(5,)
[[0.3669359  0.14189

KeyboardInterrupt: 