Double Deep Q-Learning (DDQN)

The difference between DDQN and DQN:
- In DQN, the same network is used both to select the best next action and to estimate the value of that action. Taking a maximum over noisy estimates is biased upward, so this leads to overoptimistic Q-value estimates, which may result in suboptimal policies.
- DDQN decouples action selection from action evaluation by using two networks: the online Q-network (weights $\theta$) and the target Q-network (weights $\theta'$).
  - The online Q-network selects the best next action; the target Q-network estimates the value of that action.
  - The DDQN target value is computed as follows:
    - use the online Q-network to select the best next action: $a^* = \arg\max_{a'} Q(s', a'; \theta)$
    - use the target Q-network to estimate the value of taking this action: $Q(s', a^*; \theta')$
    - combine them with the Bellman optimality equation: $y = r + \gamma\,(1 - d)\,Q(s', a^*; \theta')$, where $d = 1$ if $s'$ is terminal and $0$ otherwise
  - The online Q-network is updated as in DQN, and the target Q-network is updated periodically by copying the weights from the online Q-network.
 

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import optim
import numpy as np
import gymnasium as gym
import random
from collections import deque
import matplotlib.pyplot as plt


In [None]:

class QNetwork(nn.Module):
    """Two-layer MLP mapping a state vector to one Q-value per action.

    Attributes:
        num_actions: Width of the output layer, i.e. the number of discrete
            actions. Defaults to 2 (the value previously hard-coded here).
        hidden_size: Width of the single ReLU hidden layer. Defaults to 64.
    """

    num_actions: int = 2
    hidden_size: int = 64

    def setup(self):
        # Submodule names ("hidden", "output") define the parameter tree keys.
        self.hidden = nn.Dense(self.hidden_size)
        self.output = nn.Dense(self.num_actions)

    def __call__(self, x):
        """Return Q-values; the last axis of ``x`` is mapped to ``num_actions`` values."""
        x = nn.relu(self.hidden(x))
        return self.output(x)


In [None]:

class DDQNAgent:
    """Double DQN agent with an epsilon-greedy policy and uniform experience replay.

    Flax linen modules are stateless, so the agent keeps two parameter pytrees
    for the same ``QNetwork`` architecture:

    * ``self.params`` -- online network, trained with Adam on every ``replay``;
    * ``self.target_params`` -- target network, a periodic copy of the online
      parameters used only to *evaluate* the action the online network selects.
    """

    def __init__(self, state_size, action_size, learning_rate=1e-3, gamma=0.99, buffer_size=10000, batch_size=64,
                 update_target_every=200, epsilon_decay=0.995, epsilon_min=0.01, seed=0):
        """Create the networks, replay buffer and optimizer.

        Args:
            state_size: Length of the (flat) observation vector.
            action_size: Number of discrete actions; must equal the width of
                ``QNetwork``'s output layer.
            learning_rate: Adam step size.
            gamma: Discount factor.
            buffer_size: Maximum number of stored transitions.
            batch_size: Minibatch size for each ``replay`` update.
            update_target_every: Copy online -> target every this many env steps.
            epsilon_decay: Multiplicative epsilon decay applied after each update.
            epsilon_min: Lower bound for epsilon.
            seed: PRNG seed used to initialise the network parameters.

        Raises:
            ValueError: if ``QNetwork`` does not output ``action_size`` Q-values.
        """
        # flax.optim no longer exists in Flax; JAX ships a small Adam implementation.
        from jax.example_libraries import optimizers

        self.state_size = state_size
        self.action_size = action_size
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.update_target_every = update_target_every
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.epsilon = 1.0
        self.memory = deque(maxlen=buffer_size)
        self.steps = 0
        self._update_count = 0  # number of gradient steps; drives Adam's bias correction

        # Both networks share one architecture; only their parameters differ.
        self.Q = QNetwork()
        self.target_Q = self.Q

        dummy_state = jnp.zeros((1, state_size), dtype=jnp.float32)
        self.params = self.Q.init(jax.random.PRNGKey(seed), dummy_state)
        n_outputs = self.Q.apply(self.params, dummy_state).shape[-1]
        if n_outputs != action_size:
            raise ValueError(
                f"QNetwork outputs {n_outputs} Q-values but action_size is {action_size}"
            )
        self.sync_target()

        opt_init, self._opt_update, self._get_params = optimizers.adam(learning_rate)
        self.opt_state = opt_init(self.params)

        # jit pure functions only: `self` itself cannot be traced by jax.jit.
        self._q_fn = jax.jit(self.Q.apply)
        self._train_step = jax.jit(self._make_train_step())

    def _make_train_step(self):
        """Build the pure DDQN gradient step ``(step, opt_state, target_params, batch...) -> (loss, opt_state)``."""
        apply_fn = self.Q.apply
        gamma = self.gamma
        opt_update = self._opt_update
        get_params = self._get_params

        def loss_fn(params, target_params, states, actions, rewards, next_states, dones):
            # Online Q(s, a) for the actions that were actually taken.
            q_taken = jnp.take_along_axis(apply_fn(params, states), actions[:, None], axis=1)[:, 0]
            # DDQN: the online network *selects* a* = argmax_a Q(s', a; theta) ...
            best_actions = jnp.argmax(apply_fn(params, next_states), axis=1)
            # ... and the target network *evaluates* it: Q(s', a*; theta').
            next_q = jnp.take_along_axis(apply_fn(target_params, next_states), best_actions[:, None], axis=1)[:, 0]
            targets = rewards + (1.0 - dones) * gamma * next_q
            # TD targets are constants w.r.t. the online parameters.
            targets = jax.lax.stop_gradient(targets)
            return jnp.mean(jnp.square(q_taken - targets))

        def train_step(step, opt_state, target_params, states, actions, rewards, next_states, dones):
            params = get_params(opt_state)
            loss, grads = jax.value_and_grad(loss_fn)(
                params, target_params, states, actions, rewards, next_states, dones
            )
            return loss, opt_update(step, grads, opt_state)

        return train_step

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer.

        ``done`` should be True only for true terminal states: it disables
        bootstrapping from ``next_state`` in the TD target.
        """
        self.memory.append((state, action, reward, next_state, done))

    def sync_target(self):
        """Copy the online parameters into the target network.

        JAX arrays are immutable and ``self.params`` is rebound (never mutated
        in place) after each update, so sharing the reference is a safe copy.
        """
        self.target_params = self.params

    def act(self, state, apply_epsilon=True):
        """Return an action index (Python ``int``) for a single state.

        With probability ``epsilon`` (when ``apply_epsilon``) a uniformly random
        action is returned; otherwise the greedy action of the online network.
        Assumes ``state`` is a flat observation vector.
        """
        if apply_epsilon and np.random.rand() <= self.epsilon:
            return np.random.randint(self.action_size)
        batch = jnp.asarray(state, dtype=jnp.float32).reshape(1, -1)
        return int(jnp.argmax(self._q_fn(self.params, batch)[0]))

    def forward_pass(self, params, state):
        """Return the Q-network output for ``state`` under ``params`` (jit-compiled)."""
        return self._q_fn(params, state)

    def replay(self):
        """Do one DDQN gradient step on a random minibatch, then decay epsilon.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        _, self.opt_state = self._train_step(
            self._update_count,
            self.opt_state,
            self.target_params,
            np.asarray(states, dtype=np.float32),
            np.asarray(actions, dtype=np.int32),
            np.asarray(rewards, dtype=np.float32),
            np.asarray(next_states, dtype=np.float32),
            np.asarray(dones, dtype=np.float32),
        )
        self._update_count += 1
        self.params = self._get_params(self.opt_state)

        # Epsilon decays once per gradient update (i.e. per env step once warmed up).
        self.epsilon = max(self.epsilon_min, self.epsilon_decay * self.epsilon)

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return the initial observation as a float32 array.

        Accepts the Gymnasium API (``(obs, info)``) and the legacy Gym API (``obs``).
        """
        out = env.reset()
        obs = out[0] if isinstance(out, tuple) else out
        return np.asarray(obs, dtype=np.float32)

    @staticmethod
    def _step_env(env, action):
        """Step ``env``; return ``(next_obs, reward, terminated, truncated)``.

        Accepts the Gymnasium 5-tuple API and the legacy Gym 4-tuple API (whose
        ``done`` flag is treated as termination).
        """
        out = env.step(action)
        if len(out) == 5:
            obs, reward, terminated, truncated, _ = out
        else:
            obs, reward, terminated, _ = out
            truncated = False
        return np.asarray(obs, dtype=np.float32), float(reward), bool(terminated), bool(truncated)

    def train(self, env, episodes, render=False, solved_window=100, max_episode_steps=200, reward_threshold=175):
        """Train for up to ``episodes`` episodes and return the per-episode rewards.

        Training stops early once the mean reward over the most recent
        ``solved_window`` episodes reaches ``reward_threshold``. Each episode is
        additionally capped at ``max_episode_steps`` steps.
        """
        episode_rewards = []

        for e in range(episodes):
            state = self._reset_env(env)
            done = False
            total_reward = 0
            steps_in_episode = 0

            while not done:
                if render:
                    env.render()

                action = self.act(state)
                next_state, reward, terminated, truncated = self._step_env(env, action)

                # Only true termination cuts off bootstrapping; a time-limit
                # truncation still has a well-defined future value.
                self.remember(state, action, reward, next_state, terminated)
                done = terminated or truncated
                state = next_state
                total_reward += reward

                self.steps += 1
                steps_in_episode += 1

                if self.steps % self.update_target_every == 0:
                    self.sync_target()

                self.replay()

                # Check if the agent reached max_episode_steps in a single episode
                if steps_in_episode >= max_episode_steps:
                    print(f"Agent reached max_episode_steps in episode {e}")
                    break

            episode_rewards.append(total_reward)

            # Average reward over the last 'solved_window' episodes, once available.
            if len(episode_rewards) >= solved_window:
                avg_reward = np.mean(episode_rewards[-solved_window:])
                print(f'Episode: {e}, Average Reward: {avg_reward}')

                if avg_reward >= reward_threshold:
                    # `e` is 0-based, so e + 1 episodes have been played.
                    print(f"CartPole-v1 solved in {e + 1} episodes!")
                    break

        return episode_rewards




In [None]:

env = gym.make("CartPole-v1", render_mode="rgb_array")

# Observation-vector length and number of discrete actions of the environment.
state_size, action_size = env.observation_space.shape[0], env.action_space.n

agent = DDQNAgent(state_size=state_size, action_size=action_size)

episodes = 500
episode_rewards = agent.train(env, episodes=episodes)


In [None]:
# plot training 
def plot_moving_average_reward(episode_rewards, window_size=10):
    cumsum_rewards = np.cumsum(episode_rewards)
    moving_avg_rewards = (cumsum_rewards[window_size:] - cumsum_rewards[:-window_size]) / window_size

    plt.plot(moving_avg_rewards)
    plt.xlabel('Episode')
    plt.ylabel('Moving Average Reward')
    plt.title('Moving Average Reward over Episodes')
    plt.show()

# Show the 10-episode moving average of the training rewards.
plot_moving_average_reward(episode_rewards, window_size=10)


In [None]:
# test
# Evaluate the trained agent greedily and render every frame inline.
# need a virtual display for rendering in docker
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()
from IPython import display as ipythondisplay

# Test the trained agent
n_test_episodes = 10

print("\nTesting the trained agent...")
try:
    for episode in range(n_test_episodes):
        state, _ = env.reset()
        state = jnp.array(state, dtype=jnp.float32)

        total_reward = 0
        done = False
        plt.clf()
        # Create one image per episode and update its pixels each frame instead
        # of stacking a new image artist onto the axes at every step.
        frame = plt.imshow(env.render())
        while not done:
            # Greedy evaluation: no epsilon-random exploration.
            action = int(agent.act(state, apply_epsilon=False))
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Also stop on truncation: a well-trained agent may never trigger
            # termination, so ignoring truncation could loop indefinitely.
            done = terminated or truncated
            state = jnp.array(next_state, dtype=jnp.float32)
            total_reward += reward
            frame.set_data(env.render())
            ipythondisplay.clear_output(wait=True)
            ipythondisplay.display(plt.gcf())

        ipythondisplay.clear_output(wait=True)

        print(f"Test Episode {episode + 1}, Total Reward: {total_reward}")
finally:
    # Release the figure, the environment and the virtual display even on error.
    plt.close()
    env.close()
    display.stop()