Double Deep Q-Learning (DDQN)

The difference between DDQN and DQN:
- In DQN, the same network is used to select the best action and to estimate the value of that action. This can lead to an overoptimistic estimation of Q-values, which may result in suboptimal policies.
- DDQN decouples the action selection and action value estimation by using two separate networks: the online Q-network (with weights θ) and the target Q-network (with weights θ').
  - the online Q-network is used to select the best action, the target Q-network is used to estimate the value of that action
  - the target value computation in DDQN is as follows:
    - use the online Q to select the best action
    - use the target Q to estimate the value of taking this action
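    - compute the target value from the Bellman optimality equation, using the target Q-network's estimate for the selected action: $y = r + \gamma\,(1 - \text{done})\,Q_{\theta'}\big(s', \arg\max_{a'} Q_{\theta}(s', a')\big)$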
  - the online Q-network is updated by gradient descent in the same way as in DQN, and the target Q-network is updated periodically by copying the weights from the online Q-network.
 

In [None]:
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from collections import deque
import random
from flax import linen as nn
import optax
import matplotlib.pyplot as plt

In [None]:
class QNetwork(nn.Module):
    """Fully connected network that outputs one Q-value per action.

    The input passes through two 64-unit ReLU layers followed by a linear
    output layer with ``action_size`` units.
    """

    # Number of output units, i.e. one Q-value per discrete action.
    action_size: int

    def setup(self):
        """Declare the two hidden layers and the linear output layer."""
        hidden_units = 64
        self.dense1 = nn.Dense(features=hidden_units)
        self.dense2 = nn.Dense(features=hidden_units)
        self.dense3 = nn.Dense(features=self.action_size)

    def __call__(self, x):
        """Return the Q-values the network produces for input ``x``."""
        activations = x
        for hidden_layer in (self.dense1, self.dense2):
            activations = nn.relu(hidden_layer(activations))
        # No activation on the output: Q-values are unbounded.
        return self.dense3(activations)

In [None]:
class DDQNAgent:
    """Double DQN agent with epsilon-greedy exploration and a replay buffer.

    The agent keeps two parameter sets for the same ``QNetwork``: the online
    parameters (``params``), trained on every call to ``replay``, and the
    target parameters (``target_params``), which are overwritten with the
    online parameters every ``update_target_every`` training steps.

    Args:
        state_size: Length of the observation vector.
        action_size: Number of discrete actions.
        rng_key: JAX PRNG key used to initialise the network parameters.
        buffer_size: Maximum number of transitions kept in the replay buffer.
        batch_size: Number of transitions sampled per training step.
        gamma: Discount factor.
        lr: Adam learning rate.
        epsilon_start: Initial exploration probability.
        epsilon_end: Lower bound for the exploration probability.
        epsilon_decay: Multiplicative decay applied to epsilon after every
            training step.
        update_target_every: Number of training steps between target syncs.
    """

    def __init__(
        self,
        state_size,
        action_size,
        rng_key,
        buffer_size=10000,
        batch_size=64,
        gamma=0.99,
        lr=1e-3,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=0.995,
        update_target_every=1000,
    ):
        self.state_size = state_size
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.lr = lr
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.update_target_every = update_target_every
        # Bounded FIFO buffer: the oldest transitions are evicted first.
        self.memory = deque(maxlen=buffer_size)

        # Only the second half of the split is needed (for initialisation).
        _, rng_key_init = jax.random.split(rng_key)
        self.network = QNetwork(action_size)
        self.params = self.network.init(rng_key_init, jnp.ones((state_size,)))
        # JAX arrays are immutable, so sharing the pytree is a safe "copy".
        self.target_params = self.params
        self.optimizer = optax.adam(self.lr)
        self.opt_state = self.optimizer.init(self.params)

        # Number of completed training steps (calls to replay that trained).
        self.steps = 0

    def sync_target(self):
        """Overwrite the target parameters with the current online parameters."""
        self.target_params = self.params

    def act(self, state):
        """Choose an action for ``state`` with an epsilon-greedy policy.

        Returns:
            A Python ``int``: a uniformly random action with probability
            ``epsilon``, otherwise the arg-max of the online Q-values.
        """
        if random.random() <= self.epsilon:
            return random.randint(0, self.action_size - 1)

        batched_state = jnp.expand_dims(jnp.array(state, dtype=jnp.float32), axis=0)
        q_values = self.network.apply(self.params, batched_state)
        return int(jnp.argmax(q_values))

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def replay(self):
        """Run one Double DQN training step on a random minibatch.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. Otherwise it takes one Adam step on the mean squared TD
        error, decays epsilon, and syncs the target parameters every
        ``update_target_every`` steps.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))

        states = jnp.array(states, dtype=jnp.float32)
        actions = jnp.array(actions, dtype=jnp.int32)
        rewards = jnp.array(rewards, dtype=jnp.float32)
        next_states = jnp.array(next_states, dtype=jnp.float32)
        dones = jnp.array(dones, dtype=jnp.float32)

        def gather_per_row(values, indices):
            # Pick values[i, indices[i]] for every row i, giving shape (batch,).
            # Indexing each row with the whole index vector instead would
            # produce a (batch, batch) matrix that mixes unrelated samples.
            return jnp.take_along_axis(values, indices[:, None], axis=-1)[:, 0]

        def loss_fn(params):
            q_values = self.network.apply(params, states)
            online_q_next = self.network.apply(params, next_states)
            target_q_next = self.network.apply(self.target_params, next_states)

            # Double DQN: the online network selects the next action, the
            # target network evaluates it.
            next_action = jnp.argmax(online_q_next, axis=-1)
            q_target_next = gather_per_row(target_q_next, next_action)
            targets = rewards + self.gamma * (1.0 - dones) * q_target_next
            # The TD target is treated as a constant during optimisation.
            targets = jax.lax.stop_gradient(targets)

            q_taken = gather_per_row(q_values, actions)
            return jnp.mean((targets - q_taken) ** 2)

        _, gradients = jax.value_and_grad(loss_fn)(self.params)
        updates, self.opt_state = self.optimizer.update(gradients, self.opt_state)
        self.params = optax.apply_updates(self.params, updates)

        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        self.steps += 1
        if self.steps % self.update_target_every == 0:
            self.sync_target()


In [None]:
# Training and testing the DDQN agent
# NOTE(review): `.env` unwraps the outermost wrapper returned by gym.make,
# presumably the time-limit wrapper -- confirm. Episode length is capped
# manually below with `max_episode_steps`.
env = gym.make("CartPole-v1", render_mode="rgb_array").env
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
rng_key = jax.random.PRNGKey(42)

agent = DDQNAgent(state_size, action_size, rng_key)

n_episodes = 500
reward_history = []
max_episode_steps = 200  # env.spec.max_episode_steps
reward_threshold = 175  # env.spec.reward_threshold
solved_window = 100

for episode in range(n_episodes):
    state, _ = env.reset()
    state = jnp.array(state, dtype=jnp.float32)

    total_reward = 0
    done = False
    step_in_episode = 0

    while not done:
        action = agent.act(state)
        # Gymnasium's step returns (obs, reward, terminated, truncated, info).
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        next_state = jnp.array(next_state, dtype=jnp.float32)

        # Store only `terminated`: a truncated episode did not reach a
        # terminal state, so its last transition should still bootstrap.
        agent.remember(state, action, reward, next_state, terminated)
        agent.replay()

        state = next_state
        total_reward += reward
        step_in_episode += 1

        # check if the max_episode_steps are met. if so, terminate this episode
        if step_in_episode >= max_episode_steps:
            print(f"Agent reached max_episode_steps in episode {episode}.")
            break

    reward_history.append(total_reward)
    print(f"Episode {episode}, Total Reward: {total_reward}")

    # stop training if average reward reaches requirement
    # Calculate the average reward over the last 'solved_window' episodes.
    # Check the number of recorded episodes (not the 0-based index) so the
    # test starts as soon as a full window is available.
    if len(reward_history) >= solved_window:
        avg_reward = np.mean(reward_history[-solved_window:])
        print(f'Episode: {episode}, Average Reward: {avg_reward}')

        if avg_reward >= reward_threshold:
            # `episode` is 0-based, so the episode count is episode + 1.
            print(f"CartPole-v1 solved in {episode + 1} episodes!")
            break

# Plot the historical rewards
plt.plot(reward_history)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Historical Rewards for CartPole-v1")
plt.show()

In [None]:
# plot training 
def plot_moving_average_reward(episode_rewards, window_size=10):
    """Plot the moving average of per-episode rewards.

    Args:
        episode_rewards: Sequence of total rewards, one per episode.
        window_size: Number of consecutive episodes averaged per point.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    rewards = np.asarray(episode_rewards, dtype=float)
    # Prepend a zero so the difference below also yields the very first
    # window (episodes 0 .. window_size - 1), which is otherwise dropped.
    cumsum_rewards = np.concatenate(([0.0], np.cumsum(rewards)))
    moving_avg_rewards = (cumsum_rewards[window_size:] - cumsum_rewards[:-window_size]) / window_size

    # Place each average at the index of the last episode in its window so
    # the x axis really is the episode number.
    last_episode_in_window = np.arange(len(moving_avg_rewards)) + (window_size - 1)

    plt.plot(last_episode_in_window, moving_avg_rewards)
    plt.xlabel('Episode')
    plt.ylabel('Moving Average Reward')
    plt.title('Moving Average Reward over Episodes')
    plt.show()

# Plot the smoothed training curve from the rewards recorded during training,
# using the default window of 10 episodes.
plot_moving_average_reward(reward_history)


In [None]:
# test
# need a virtual display for rendering in docker
from pyvirtualdisplay import Display
from IPython import display as ipythondisplay

display = Display(visible=0, size=(1400, 900))
display.start()

# Test the trained agent
n_test_episodes = 10
# Cap the episode length, as the training loop does: without a cap a good
# policy may never reach a terminal state and the loop would not finish.
max_test_steps = 200

# NOTE: agent.act is still epsilon-greedy, so a small fraction of the test
# actions may be random (agent.epsilon is not reset here).
print("\nTesting the trained agent...")
frame_image = None
try:
    for episode in range(n_test_episodes):
        state, _ = env.reset()
        state = jnp.array(state, dtype=jnp.float32)

        total_reward = 0
        done = False
        steps_taken = 0
        while not done and steps_taken < max_test_steps:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_state = jnp.array(next_state, dtype=jnp.float32)
            screen = env.render()
            state = next_state
            total_reward += reward
            steps_taken += 1

            # Reuse one image artist instead of stacking a new image on the
            # axes for every frame.
            if frame_image is None:
                frame_image = plt.imshow(screen)
            else:
                frame_image.set_data(screen)
            ipythondisplay.clear_output(wait=True)
            ipythondisplay.display(plt.gcf())

        ipythondisplay.clear_output(wait=True)

        print(f"Test Episode {episode + 1}, Total Reward: {total_reward}")
finally:
    # Release the environment and the virtual display even if the test loop
    # is interrupted.
    env.close()
    display.stop()