In [1]:
import jax
import numpy as np
import jax.numpy as jnp
import gymnasium as gym
import flax
from flax.training.train_state import TrainState
import optax
import functools
import matplotlib.pyplot as plt
import tqdm

In [2]:
# Learning rate handed to the Adam optimiser.
ALPHA = 1e-3
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.99
# Number of transitions drawn from the replay buffer per gradient update.
BATCH_SIZE = 64
# Maximum number of transitions the replay buffer retains.
CAPACITY = 20_000

In [3]:
def smooth_rewards(rewards, window_size=10):
    """Smooth a reward curve with a centred moving average.

    Element ``i`` of the result is the mean of
    ``rewards[i - window_size // 2 : i + window_size // 2 + 1]``, with the
    window clipped at both ends of the sequence (so edge elements average
    over fewer samples).

    Args:
        rewards: Sequence or array of rewards, indexed along its first axis.
        window_size: Nominal window width; ``window_size // 2`` samples are
            taken on each side of the centre element.

    Returns:
        A floating-point NumPy array with the same shape as ``rewards``.
    """
    values = np.asarray(rewards)
    # Integer (or bool) inputs must be promoted: writing a fractional mean
    # into an integer-typed output array would silently truncate it.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)

    half_window = window_size // 2
    smoothed_rewards = np.zeros_like(values)
    for i in range(len(values)):
        window_start = max(0, i - half_window)
        window_end = min(len(values), i + half_window + 1)
        smoothed_rewards[i] = np.mean(values[window_start:window_end])
    return smoothed_rewards


def plot_data(mean, std):
    """Plot a mean curve, its smoothed version, and a +/- one-std band.

    Args:
        mean: Array of per-step mean values.
        std: Array of per-step standard deviations, same length as ``mean``.
    """
    steps = range(len(mean))

    # Raw curve first, then the moving-average overlay.
    plt.plot(steps, mean, color='blue', label='Mean')
    plt.plot(steps, smooth_rewards(mean), color='orange', label='smoothed')

    # Shaded band spanning one standard deviation either side of the mean.
    lower_band = mean - std
    upper_band = mean + std
    plt.fill_between(
        steps, lower_band, upper_band,
        color='blue', alpha=0.3, label='Mean ± Std',
    )

    plt.title('Mean with Standard Deviation')
    plt.xlabel('Steps')
    plt.ylabel('Rewards')
    plt.grid(True)
    plt.legend()
    plt.show()
    
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly interpolate from ``start_e`` to ``end_e`` over ``duration`` steps.

    The value moves along a straight line from ``start_e`` at ``t == 0`` and is
    held at ``end_e`` once the line reaches it. Both decaying
    (``start_e > end_e``) and increasing (``start_e < end_e``) schedules are
    supported.

    Args:
        start_e: Value at ``t == 0``.
        end_e: Final value; the result never passes it.
        duration: Number of steps taken to go from ``start_e`` to ``end_e``.
            A duration of ``0`` means the schedule is already finished.
        t: Current step.

    Returns:
        The scheduled value at step ``t``.
    """
    if duration == 0:
        # Nothing to interpolate over; avoids a ZeroDivisionError.
        return end_e

    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp in the direction of travel: a decaying schedule is floored at
    # end_e, an increasing one is capped at end_e.
    if end_e <= start_e:
        return max(value, end_e)
    return min(value, end_e)

In [4]:
class ValueNetworkMean(flax.linen.Module):
    """Dueling Q-network that centres advantages on their mean.

    A shared 16-unit leaky-ReLU layer feeds two linear heads: a scalar state
    value and one advantage per action. The Q-values are
    ``value + (advantage - mean(advantage))`` over the last axis.

    Attributes:
        action_dim: Number of outputs of the advantage head.
    """
    action_dim: int

    @flax.linen.compact
    def __call__(self, x: jnp.ndarray):
        # Shared trunk.
        hidden = flax.linen.leaky_relu(flax.linen.Dense(16)(x))

        # Two heads on top of the shared features.
        state_value = flax.linen.Dense(1)(hidden)
        advantages = flax.linen.Dense(self.action_dim)(hidden)

        # Subtract the per-sample mean advantage before adding the value.
        centred = advantages - jnp.mean(advantages, axis=-1, keepdims=True)
        return state_value + centred


class ValueNetworkMax(flax.linen.Module):
    """Dueling Q-network that shifts advantages by their maximum.

    A shared 16-unit leaky-ReLU layer feeds two linear heads: a scalar state
    value and one advantage per action. The Q-values are
    ``value + (advantage - max(advantage))`` over the last axis, so the
    highest-advantage action's Q-value equals the state value.

    Attributes:
        action_dim: Number of outputs of the advantage head.
    """
    action_dim: int

    @flax.linen.compact
    def __call__(self, x: jnp.ndarray):
        # Shared trunk.
        hidden = flax.linen.leaky_relu(flax.linen.Dense(16)(x))

        # Two heads on top of the shared features.
        state_value = flax.linen.Dense(1)(hidden)
        advantages = flax.linen.Dense(self.action_dim)(hidden)

        # Subtract the per-sample maximum advantage before adding the value.
        shifted = advantages - jnp.max(advantages, axis=-1, keepdims=True)
        return state_value + shifted

In [5]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions.

    Transitions are appended until ``capacity`` is reached; after that the
    oldest slot is overwritten in round-robin order.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, transition):
        """Store ``transition``, overwriting the slot at ``position`` when full."""
        slot = self.position
        if len(self.buffer) >= self.capacity:
            self.buffer[slot] = transition
        else:
            self.buffer.append(transition)
        # Advance the write cursor, wrapping at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` random buffer indices drawn via ``np.random.choice``."""
        stored = len(self.buffer)
        return np.random.choice(stored, size=(batch_size,))

    def get_batch(self, indices):
        """Gather the transitions at ``indices`` into five stacked JAX arrays.

        Returns:
            ``(states, actions, rewards, next_states, dones)``, each built by
            ``jnp.array`` from the corresponding field of every transition.
        """
        # Transpose the list of 5-field transitions into 5 per-field tuples.
        fields = zip(*(self.buffer[idx] for idx in indices))
        states, actions, rewards, next_states, dones = (
            jnp.array(field) for field in fields
        )
        return states, actions, rewards, next_states, dones

In [6]:
# The base class is spelled with its full module path rather than the bare
# (shadowed) name `TrainState`: subclassing the bare name makes every re-run
# of this cell subclass the previous subclass instead of the Flax original.
class TrainState(flax.training.train_state.TrainState):
    """Flax train state extended with a second parameter set for the target network."""
    # Parameters of the target network used to compute bootstrapped TD targets.
    target_params: flax.core.FrozenDict

In [7]:
class DuelingDQNMean:
    """Dueling DQN agent built on the mean-centred advantage network.

    The agent owns an online network, a target network, an Adam optimiser and
    a replay buffer. ``train_single_step`` plays one episode with an
    epsilon-greedy policy and performs one gradient update per environment
    step once the replay buffer is sufficiently full.

    NOTE: exploration (``sample``) and replay sampling draw from NumPy's
    global random state, which ``seed`` does not control; only the network
    initialisation and the first environment reset are seeded.
    """

    # Hard cap on the number of environment steps in one episode.
    MAX_EPISODE_STEPS = 500
    # Gradient updates start once the buffer holds more than this many transitions.
    LEARNING_STARTS = 128
    # Epsilon is held at its starting value while the episode counter is below this.
    EPSILON_WARMUP_EPISODES = 50
    # Duration (in episodes) passed to the linear epsilon schedule.
    EPSILON_DECAY_EPISODES = 500

    def __init__(self, env, num_actions, observation_shape, seed=0, tau=0.005) -> None:
        """Build networks, optimiser state and replay buffer.

        Args:
            env: Gymnasium-style environment exposing ``reset`` and ``step``.
            num_actions: Size of the discrete action space.
            observation_shape: Shape of a single observation.
            seed: Seed for parameter initialisation and the first env reset.
            tau: Soft-update rate; after every gradient step the target
                parameters move this fraction of the way towards the online
                parameters.
        """
        self.seed = seed
        self.rng = jax.random.PRNGKey(seed)
        self.num_actions = num_actions
        self.observation_shape = observation_shape
        self.env = env
        self.tau = tau

        self.value = ValueNetworkMean(num_actions)
        # Online and target networks start from the same parameters.
        initial_params = self.value.init(self.rng, jnp.ones(observation_shape))
        self.value_state = TrainState.create(
            apply_fn=self.value.apply,
            params=initial_params,
            target_params=initial_params,
            tx=optax.adam(learning_rate=ALPHA)
        )
        self.value.apply = jax.jit(self.value.apply)
        # Compile the gradient step once; rebuilding a jitted closure on every
        # update would force a fresh trace and compile each time.
        self._train_step = self._build_train_step()
        self.replay_buffer = ReplayBuffer(CAPACITY)
        self.counter = 1
        # The environment is seeded on the first reset only.
        self._env_seeded = False

    def _build_train_step(self):
        """Return a jitted function performing one TD update plus target sync."""
        tau = self.tau

        def train_step(value_state, states, actions, rewards, next_states, dones):
            # Bootstrapped target from the (slow-moving) target network.
            next_q = value_state.apply_fn(value_state.target_params, next_states)
            next_q = jnp.max(next_q, axis=-1)
            td_target = rewards + (1 - dones) * GAMMA * next_q

            def mse_loss(params):
                q_pred = value_state.apply_fn(params, states)
                # Select the Q-value of the action actually taken in each row.
                q_taken = q_pred[jnp.arange(q_pred.shape[0]), actions.squeeze()]
                return ((jax.lax.stop_gradient(td_target) - q_taken) ** 2).mean()

            loss_value, grads = jax.value_and_grad(mse_loss)(value_state.params)
            value_state = value_state.apply_gradients(grads=grads)
            # Without this the target network would stay frozen at its
            # initial parameters for the whole run.
            value_state = value_state.replace(
                target_params=optax.incremental_update(
                    value_state.params, value_state.target_params, tau))
            return value_state, loss_value

        return jax.jit(train_step)

    def sample(self, state, epsilon=0.1):
        """Pick an action epsilon-greedily.

        Args:
            state: Batched observation; the action for row 0 is returned.
            epsilon: Probability of choosing a uniformly random action.

        Returns:
            An integer action index.
        """
        if np.random.uniform() < epsilon:
            return np.random.randint(0, self.num_actions)
        q_values = self.value.apply(self.value_state.params, state)
        return int(np.asarray(q_values).argmax(axis=-1)[0])

    def update(self, states, actions, rewards, next_states,  dones):
        """Run one gradient step on a batch and soft-update the target network.

        Returns:
            The scalar mean-squared TD loss for the batch.
        """
        self.value_state, loss_value = self._train_step(
            self.value_state, states, actions, rewards, next_states, dones)
        return loss_value

    def train_single_step(self):
        """Play one episode, learning from replay after every step.

        Returns:
            ``(episode_loss, episode_rewards)``: the summed batch losses and
            the summed rewards collected during the episode.
        """
        # Seed only the first reset; passing the same seed every episode
        # would restart each episode from the identical initial state.
        if self._env_seeded:
            state = self.env.reset()[0]
        else:
            state = self.env.reset(seed=self.seed)[0]
            self._env_seeded = True

        # Epsilon stays at its start value during warm-up, then follows the
        # linear schedule indexed by the episode counter.
        epsilon = linear_schedule(
            start_e=1, end_e=0.05, duration=self.EPSILON_DECAY_EPISODES,
            t=0 if self.counter < self.EPSILON_WARMUP_EPISODES else self.counter)

        episode_loss, episode_rewards = 0, 0
        for _ in range(self.MAX_EPISODE_STEPS):
            action = self.sample(np.expand_dims(state, axis=0), epsilon)
            next_state, reward, done, truncated, info = self.env.step(action)
            # Only true termination zeroes the bootstrap term; a time-limit
            # truncation leaves the next state non-terminal.
            self.replay_buffer.push(
                [state, action, reward, next_state, done])
            state = next_state
            episode_rewards += reward

            if truncated or done:
                break

            if len(self.replay_buffer.buffer) > self.LEARNING_STARTS:
                indices = self.replay_buffer.sample(BATCH_SIZE)
                states, actions, rewards, next_states, dones = self.replay_buffer.get_batch(
                    indices)
                loss_values = self.update(
                    states, actions, rewards, next_states,  dones)
                episode_loss += loss_values

        self.counter += 1
        return episode_loss,episode_rewards

In [8]:
class Simulation:
    """Train an agent class on a Gymnasium environment across several seeds."""

    def __init__(self, env_name, algorithm) -> None:
        """Create the environment and record its action/observation spaces.

        Args:
            env_name: Environment id passed to ``gym.make``.
            algorithm: Agent class, constructed as
                ``algorithm(env, num_actions, observation_shape, seed=...)``
                and expected to provide ``train_single_step()``.
        """
        self.env_name = env_name
        self.algorithm = algorithm
        self.env = gym.make(self.env_name)
        self.num_actions = self.env.action_space.n
        self.observation_shape = self.env.observation_space.shape

    def train(self, episodes=1000, num_seeds=5):
        """Train a fresh agent for each seed and record per-episode results.

        Results are stored on the instance as ``self.losses`` and
        ``self.rewards``, both of shape ``(num_seeds, episodes)``. The most
        recently trained agent is kept in ``self.algo``.

        Args:
            episodes: Number of episodes to run for every seed.
            num_seeds: Number of independent runs; seeds are ``0..num_seeds-1``.
        """
        self.losses = np.zeros((num_seeds, episodes))
        self.rewards = np.zeros((num_seeds, episodes))

        for seed in range(num_seeds):
            # A new agent per seed; all agents share the same environment.
            self.algo = self.algorithm(
                self.env, self.num_actions, self.observation_shape, seed=seed)
            for ep in tqdm.tqdm(range(episodes)):
                loss, reward = self.algo.train_single_step()
                self.losses[seed, ep] = loss
                self.rewards[seed, ep] = reward

In [9]:
# Train the mean-aggregation dueling DQN agent on CartPole-v1 using the
# defaults of Simulation.train (1000 episodes for each of 5 seeds).
cartpole_mean_value = Simulation('CartPole-v1', algorithm=DuelingDQNMean)
cartpole_mean_value.train()

  1%|          | 9/1000 [00:10<18:30,  1.12s/it]


KeyboardInterrupt: 