In [None]:
import gymnasium as gym
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np
from collections import deque
import random
import matplotlib.pyplot as plt 


In [None]:

class QNetwork(nn.Module):
    """MLP mapping a batch of observations to one Q-value per discrete action.

    Architecture: Dense(hidden_size) -> ReLU -> Dense(hidden_size) -> ReLU ->
    Dense(num_actions). The defaults (2 actions, 64 hidden units) reproduce the
    original network, so ``QNetwork()`` builds exactly the same model and
    parameter tree as before.
    """

    num_actions: int = 2   # width of the output layer: one Q-value per action
    hidden_size: int = 64  # units in each of the two hidden layers

    @nn.compact
    def __call__(self, x):
        # Input width is inferred by Flax from `x` at init time.
        x = nn.Dense(features=self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_actions)(x)
        return x

def initialize_model_and_optimizer(rng, learning_rate=1e-3, obs_dim=4):
    """Build the Q-network, initialise its parameters and an Adam optimizer.

    Args:
        rng: JAX PRNG key used for parameter initialisation.
        learning_rate: Adam step size.
        obs_dim: length of one observation vector. The default 4 matches the
            CartPole-v1 observations used in this notebook.

    Returns:
        ``(model, params, tx, opt_state)``: the ``QNetwork`` module, its
        initial parameter pytree, the optax transformation, and its state.
    """
    model = QNetwork()
    # Dense layers infer their input width from this dummy batch of one observation.
    params = model.init(rng, jnp.zeros((1, obs_dim)))["params"]
    tx = optax.adam(learning_rate)
    opt_state = tx.init(params)
    return model, params, tx, opt_state

# Not jitted: `model` and `tx` are Python objects, not arrays. To jit it, mark them
# static, e.g. `jax.jit(train_step, static_argnums=(0, 5))`.
# NOTE(review): that assumes both objects are hashable — confirm before enabling.
def train_step(model, params, opt_state, target_params, batch, tx, gamma=0.99):
    """Run one Double-DQN gradient step on a replay minibatch.

    Args:
        model: network called as ``model.apply({'params': p}, x)``; its output is
            indexed by action along the last axis.
        params: online-network parameters (the ones being optimised).
        opt_state: optimizer state matching ``tx``.
        target_params: target-network parameters; used only to evaluate the TD target.
        batch: ``(states, actions, rewards, next_states, dones)`` with a leading
            batch axis. ``actions`` are integer action indices; ``dones`` is 1.0 for
            transitions whose target must not bootstrap, else 0.0.
        tx: optax gradient transformation that produced ``opt_state``.
        gamma: discount factor for the bootstrapped next-state value.

    Returns:
        ``(params, opt_state)`` after one optimizer update.
    """
    states, actions, rewards, next_states, dones = batch

    # Double-DQN target: the online network selects the next action, the target
    # network evaluates it. Nothing here has a differentiable path to `params`
    # (argmax has zero gradient), so compute it once outside the loss instead of
    # tracing two extra forward passes through `jax.grad`.
    next_actions = jnp.argmax(model.apply({'params': params}, next_states), axis=-1)
    next_q_target = model.apply({'params': target_params}, next_states)
    # `[:, 0]` instead of `.squeeze()` keeps shape (batch,) even when batch == 1.
    next_q_values = jnp.take_along_axis(next_q_target, next_actions[:, None], axis=-1)[:, 0]
    targets = jax.lax.stop_gradient(rewards + (1 - dones) * gamma * next_q_values)

    def loss_fn(online_params):
        # Mean squared TD error between Q(s, a_taken) and the fixed targets.
        q_values = model.apply({'params': online_params}, states)
        q_taken = jnp.take_along_axis(q_values, actions[:, None], axis=-1)[:, 0]
        return jnp.mean((targets - q_taken) ** 2)

    grads = jax.grad(loss_fn)(params)
    # Passing `params` lets transformations that need them (e.g. weight decay) work too.
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state

def epsilon_greedy_policy(model, params, observation, epsilon=0.1):
    """Pick an action for one observation with an epsilon-greedy rule.

    With probability ``epsilon`` a uniformly random action from {0, 1} is
    returned (exploration). Otherwise the network is evaluated on the
    observation and the index of the largest Q-value is returned (exploitation).

    Returns:
        The chosen action as a Python ``int``.
    """
    explore = random.random() < epsilon
    if explore:
        return random.randint(0, 1)

    # The network is applied to a batch, so add a leading axis of size 1.
    batched_observation = jnp.asarray(observation)[jnp.newaxis, ...]
    q_row = model.apply({'params': params}, batched_observation)
    best_action = jnp.argmax(q_row)
    return int(best_action)

def update_target_params(state_params, target_params, tau=0.001):
    """Blend online parameters into the target parameters (soft / Polyak update).

    Computes ``tau * state_params + (1 - tau) * target_params`` leaf by leaf.
    There are two common ways to keep a target Q-network in sync:

    1. hard update: copy the online parameters at a fixed frequency, e.g. every
       100 steps (equivalent to calling this with ``tau=1.0``);
    2. soft update: blend in a small fraction ``tau`` at every step, which is
       how this notebook uses it.

    Args:
        state_params: online-network parameter pytree.
        target_params: target-network parameter pytree with the same structure.
        tau: weight given to ``state_params`` (0 keeps the target, 1 copies online).

    Returns:
        A new pytree holding the blended parameters.
    """
    # `jax.tree_map` is deprecated and removed in newer JAX; use the tree_util API.
    return jax.tree_util.tree_map(
        lambda online, target: tau * online + (1 - tau) * target,
        state_params,
        target_params,
    )


In [None]:

# --- Training configuration --------------------------------------------------
seed = 0
num_episodes = 300
batch_size = 64
replay_capacity = 10000
epsilon = 1.0
epsilon_decay = 0.995
min_epsilon = 0.01
max_episode_steps = 200
solved_window = 100     # number of most recent episodes averaged for the "solved" check
solved_threshold = 175  # average reward over `solved_window` episodes that counts as solved

# Seed Python's `random` (used for exploration and minibatch sampling) so a
# fresh-kernel run is reproducible; the env and JAX are seeded below.
random.seed(seed)

# Keep gymnasium's TimeLimit wrapper (the previous `.env` stripped it) and cap it at
# `max_episode_steps`. Without a cap, an agent that learns to balance the pole never
# ends an episode and this cell hangs. Hitting the cap is reported as `truncated`.
env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps)

rng = jax.random.PRNGKey(seed)
model, params, tx, opt_state = initialize_model_and_optimizer(rng)
target_params = params

replay_buffer = deque(maxlen=replay_capacity)

reward_history = []
try:
    for episode in range(num_episodes):
        # Seed only the first reset; later resets continue the env's own RNG stream.
        observation, _ = env.reset(seed=seed if episode == 0 else None)
        done = False
        episode_reward = 0
        step_in_episode = 0

        while not done:
            action = epsilon_greedy_policy(model, params, observation, epsilon)
            next_observation, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
            step_in_episode += 1

            # Store `terminated`, not `done`: a time-limit truncation is not a true
            # terminal state, so its TD target must still bootstrap from next_observation.
            replay_buffer.append((observation, action, reward, next_observation, float(terminated)))
            observation = next_observation

            if len(replay_buffer) >= batch_size:
                indices = random.sample(range(len(replay_buffer)), batch_size)
                transitions = [replay_buffer[i] for i in indices]
                # Transpose the transitions into (states, actions, rewards, next_states, dones).
                batch = [jnp.stack(field) for field in zip(*transitions)]
                params, opt_state = train_step(model, params, opt_state, target_params, batch, tx)

                # soft update target params
                target_params = update_target_params(params, target_params)

        # Decay exploration once per finished episode, whether it terminated or was truncated.
        epsilon = max(min_epsilon, epsilon * epsilon_decay)
        print(f"Episode: {episode}, Reward: {episode_reward}, Epsilon: {epsilon:.4f}")
        reward_history.append(episode_reward)

        # stop training once an episode lasts max_episode_steps
        if step_in_episode >= max_episode_steps:
            print(f"Agent reached max_episode_steps in episode {episode}.")
            break

        # stop training once the average reward over the last `solved_window`
        # episodes reaches the threshold (checked as soon as that many exist)
        if len(reward_history) >= solved_window:
            avg_reward = np.mean(reward_history[-solved_window:])
            print(f'Episode: {episode}, Average Reward: {avg_reward}')

            if avg_reward >= solved_threshold:
                # `episode` is 0-based; report how many episodes were actually run.
                print(f"CartPole-v1 solved in {episode + 1} episodes!")
                break
finally:
    # Release the environment even if training is interrupted or raises.
    env.close()


In [None]:
# plot training
def plot_moving_average_reward(episode_rewards, window_size=100):
    """Plot the trailing moving average of per-episode rewards.

    Args:
        episode_rewards: per-episode total rewards, in episode order.
        window_size: number of episodes averaged per point (must be >= 1). It is
            clamped to ``len(episode_rewards)`` so a short run (training can stop
            early) still produces a curve instead of an empty plot.

    Returns:
        The plotted moving-average array (empty, with nothing plotted, when there
        are no rewards).

    Raises:
        ValueError: if ``window_size < 1``.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    rewards = np.asarray(episode_rewards, dtype=float)
    if rewards.size == 0:
        return rewards

    window = min(window_size, rewards.size)
    # Prepend 0 so cumsum[k] is the sum of the first k rewards; the difference then
    # covers every full window, including the one that starts at episode 0.
    cumsum_rewards = np.concatenate(([0.0], np.cumsum(rewards)))
    moving_avg_rewards = (cumsum_rewards[window:] - cumsum_rewards[:-window]) / window
    # Point k averages episodes k .. k+window-1; place it at the window's last episode.
    episodes = np.arange(window - 1, rewards.size)

    fig, ax = plt.subplots()
    ax.plot(episodes, moving_avg_rewards)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Moving Average Reward')
    ax.set_title('Moving Average Reward over Episodes')
    plt.show()
    return moving_avg_rewards

plot_moving_average_reward(reward_history);