In [1]:
import jax
import numpy as np
import jax.numpy as jnp
import gymnasium as gym
import flax
from flax.training.train_state import TrainState
import optax
import functools
import matplotlib.pyplot as plt
import tqdm

In [2]:
ALPHA = 0.001
GAMMA = 0.99

In [3]:
def smooth_rewards(rewards, window_size=10):
    smoothed_rewards = np.zeros_like(rewards)
    for i in range(len(rewards)):
        window_start = max(0, i - window_size // 2)
        window_end = min(len(rewards), i + window_size // 2 + 1)
        smoothed_rewards[i] = np.mean(rewards[window_start:window_end])
    return smoothed_rewards


def plot_data(mean, std):
    x = range(len(mean))

    plt.plot(x, mean, color='blue', label='Mean')
    plt.plot(x, smooth_rewards(mean), color='orange', label='smoothed')
    plt.fill_between(x, mean - std, mean + std, color='blue',
                     alpha=0.3, label='Mean ± Std')

    plt.xlabel('Steps')
    plt.ylabel('Rewards')
    plt.title('Mean with Standard Deviation')
    plt.legend()
    plt.grid(True)
    plt.show()

In [4]:
class PolicyNetwork(flax.linen.Module):
    action_dim: int

    @flax.linen.compact
    def __call__(self, x):
        x = flax.linen.Dense(16)(x)
        x = flax.linen.leaky_relu(x)
        x = flax.linen.Dense(self.action_dim)(x)
        x = flax.linen.softmax(x)
        return x


class BaselineNetwork(flax.linen.Module):

    @flax.linen.compact
    def __call__(self, x):
        x = flax.linen.Dense(16)(x)
        x = flax.linen.leaky_relu(x)
        x = flax.linen.Dense(1)(x)
        return x

In [5]:
class MC_Reinforce:
    def __init__(self, env, num_actions, observation_shape, seed=0):
        self.seed = seed
        self.rng = jax.random.PRNGKey(seed)
        self.num_actions = num_actions
        self.observation_shape = observation_shape
        self.env = env

        self.policy = PolicyNetwork(num_actions)
        self.policy_state = TrainState.create(
            apply_fn=self.policy.apply,
            params=self.policy.init(self.rng, jnp.ones(observation_shape)),
            tx=optax.adam(learning_rate=ALPHA),
        )
        self.policy.apply = jax.jit(self.policy.apply)
        # print(self.policy.tabulate(self.rng, jnp.ones(
        #     self.observation_shape)))

    def sample(self, state):
        probs = self.policy.apply(self.policy_state.params, state)[0]
        return probs

    # @functools.partial(jax.jit, static_argnums=(0,))
    def update(self, states, actions, discounted_rewards):
        @jax.jit
        def log_prob_loss(params):
            probs = self.policy.apply(params, states)
            log_probs = jnp.log(probs)
            actions_new = jax.nn.one_hot(actions, num_classes=self.num_actions)
            prob_reduce = -jnp.sum(log_probs*actions_new, axis=1)
            loss = jnp.mean(prob_reduce*discounted_rewards)
            return loss

        loss, grads = jax.value_and_grad(
            log_prob_loss)(self.policy_state.params)
        self.policy_state = self.policy_state.apply_gradients(grads=grads)
        return loss

    def train_single_step(self):
        state = self.env.reset(seed=self.seed)[0]
        key = self.rng

        episode_rewards = []
        episode_states = []
        episode_actions = []

        for _ in range(500):
            _, key = jax.random.split(key=key)
            probs = self.sample(np.expand_dims(state, axis=0))
            action = np.random.choice(self.num_actions, p=np.array(probs))
            episode_actions.append(action)
            episode_states.append(state)
            next_state, reward, done, truncated, info = self.env.step(action)
            episode_rewards.append(reward)
            state = next_state
            if done or truncated:
                break

        discounted_rewards = jnp.array([sum(reward * (GAMMA ** t) for t, reward in enumerate(episode_rewards[start:]))
                                        for start in range(len(episode_rewards))])
        gamma_t = jnp.array([sum(GAMMA ** t for t, reward in enumerate(episode_rewards[start:]))
                             for start in range(len(episode_rewards))])
        discounted_rewards = (
            discounted_rewards-discounted_rewards.mean())/(discounted_rewards.std()+1e-8)
        episode_states = jnp.array(episode_states)
        episode_actions = jnp.array(episode_actions)
        loss = self.update(episode_states, episode_actions,
                           discounted_rewards*gamma_t)
        return loss, np.sum(episode_rewards)

In [6]:
class MC_Baseline:
    def __init__(self, env, num_actions, observation_shape, seed=0):
        self.seed = seed
        self.rng = jax.random.PRNGKey(seed)
        self.num_actions = num_actions
        self.observation_shape = observation_shape
        self.env = env

        self.policy = PolicyNetwork(num_actions)
        self.policy_state = TrainState.create(
            apply_fn=self.policy.apply,
            params=self.policy.init(self.rng, jnp.ones(observation_shape)),
            tx=optax.adam(learning_rate=ALPHA),
        )
        self.policy.apply = jax.jit(self.policy.apply)

        self.baseline = BaselineNetwork()
        self.baseline_state = TrainState.create(
            apply_fn=self.baseline.apply,
            params=self.baseline.init(self.rng, jnp.zeros(observation_shape)),
            tx=optax.adam(learning_rate=ALPHA),
        )
        self.baseline.apply = jax.jit(self.baseline.apply)
        # print(self.policy.tabulate(self.rng, jnp.ones(
        #     self.observation_shape)))

    def sample(self, state):
        probs = self.policy.apply(self.policy_state.params, state)[0]
        return probs

    # @functools.partial(jax.jit, static_argnums=(0,))
    def update(self, states, actions, discounted_rewards, gamma_t):
        @jax.jit
        def mse_loss(params):
            v_s = self.baseline.apply(params, states)
            delta = jnp.subtract(discounted_rewards, jnp.reshape(v_s, (-1,)))
            loss = jnp.mean(0.5*jnp.square(delta))
            return loss, delta

        (loss_baseline, delta), grads_baseline = jax.value_and_grad(
            mse_loss, has_aux=True)(self.baseline_state.params)

        @jax.jit
        def log_prob_loss(params):
            probs = self.policy.apply(params, states)
            log_probs = jnp.log(probs)
            actions_new = jax.nn.one_hot(actions, num_classes=self.num_actions)
            prob_reduce = -jnp.sum(log_probs*actions_new, axis=1)
            loss = jnp.mean(prob_reduce*delta*gamma_t)
            return loss

        loss_policy, grads_policy = jax.value_and_grad(
            log_prob_loss)(self.policy_state.params)

        self.baseline_state = self.baseline_state.apply_gradients(
            grads=grads_baseline)
        self.policy_state = self.policy_state.apply_gradients(
            grads=grads_policy)
        return loss_policy+loss_baseline

    def train_single_step(self):
        state = self.env.reset(seed=self.seed)[0]
        key = self.rng

        episode_rewards = []
        episode_states = []
        episode_actions = []

        for _ in range(500):
            _, key = jax.random.split(key=key)
            probs = self.sample(np.expand_dims(state, axis=0))
            action = np.random.choice(self.num_actions, p=np.array(probs))
            episode_actions.append(action)
            episode_states.append(state)
            next_state, reward, done, truncated, info = self.env.step(action)
            episode_rewards.append(reward)
            state = next_state
            if done or truncated:
                break

        discounted_rewards = jnp.array([sum(reward * (GAMMA ** t) for t, reward in enumerate(episode_rewards[start:]))
                                        for start in range(len(episode_rewards))])
        gamma_t = jnp.array([sum(GAMMA ** t for t, reward in enumerate(episode_rewards[start:]))
                             for start in range(len(episode_rewards))])
        discounted_rewards = (
            discounted_rewards-discounted_rewards.mean())/(discounted_rewards.std()+1e-8)
        episode_states = jnp.array(episode_states)
        episode_actions = jnp.array(episode_actions)
        loss = self.update(episode_states, episode_actions,
                           discounted_rewards, gamma_t)
        return loss, np.sum(episode_rewards)

In [7]:
class Simulation:
    def __init__(self, env_name, algorithm) -> None:
        self.env_name = env_name
        self.algorithm = algorithm
        self.env = gym.make(self.env_name)
        self.num_actions = self.env.action_space.n
        self.observation_shape = self.env.observation_space.shape

    def train(self, episodes=1000):
        self.losses, self.rewards = np.zeros(
            (5, episodes)), np.zeros((5, episodes))

        for seed in range(5):
            self.algo = self.algorithm(
                self.env, self.num_actions, self.observation_shape, seed=seed)
            for ep in tqdm.tqdm(range(1, episodes+1)):
                loss, reward = self.algo.train_single_step()
                self.losses[seed][ep-1] = loss
                self.rewards[seed][ep-1] = reward

In [8]:
class DuelDQN_MeanAdvantage:
    def __init__(self):
        pass


class DuelDQN_MaxAdvantage:
    def __init__(self):
        pass

In [9]:
cartpole_reinforce = Simulation('CartPole-v1', algorithm=MC_Reinforce)
cartpole_reinforce.train()

100%|██████████| 1000/1000 [03:36<00:00,  4.63it/s]
100%|██████████| 1000/1000 [02:26<00:00,  6.85it/s]
100%|██████████| 1000/1000 [02:25<00:00,  6.85it/s]
100%|██████████| 1000/1000 [02:29<00:00,  6.69it/s]
100%|██████████| 1000/1000 [02:35<00:00,  6.44it/s]


In [10]:
cartpole_baseline = Simulation('CartPole-v1', algorithm=MC_Baseline)
cartpole_baseline.train()

100%|██████████| 1000/1000 [05:22<00:00,  3.10it/s]
100%|██████████| 1000/1000 [04:04<00:00,  4.09it/s]
100%|██████████| 1000/1000 [04:50<00:00,  3.45it/s]
100%|██████████| 1000/1000 [05:00<00:00,  3.33it/s]
100%|██████████| 1000/1000 [05:08<00:00,  3.24it/s]


In [11]:
acrobot_reinforce = Simulation('Acrobot-v1',algorithm=MC_Reinforce)
acrobot_reinforce.train()

100%|██████████| 1000/1000 [04:33<00:00,  3.65it/s]
100%|██████████| 1000/1000 [08:30<00:00,  1.96it/s]
100%|██████████| 1000/1000 [08:23<00:00,  1.99it/s]
100%|██████████| 1000/1000 [05:47<00:00,  2.87it/s]
 51%|█████     | 509/1000 [04:22<04:03,  2.02it/s]E0402 00:31:37.081436  358319 pjrt_stream_executor_client.cc:2804] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function:CUDA_ERROR_OUT_OF_MEMORY: out of memory
 51%|█████     | 509/1000 [04:22<04:13,  1.94it/s]


XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function:CUDA_ERROR_OUT_OF_MEMORY: out of memory

In [None]:
acrobot_baseline = Simulation('Acrobot-v1',algorithm=MC_Baseline)
acrobot_baseline.train()

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [05:47<00:00,  2.88it/s]
100%|██████████| 1000/1000 [10:48<00:00,  1.54it/s]
100%|██████████| 1000/1000 [10:55<00:00,  1.53it/s]
100%|██████████| 1000/1000 [07:16<00:00,  2.29it/s]
100%|██████████| 1000/1000 [11:01<00:00,  1.51it/s]


In [None]:
rewards_cartpole_reinforce = cartpole_reinforce.rewards
rewards_cartpole_baseline = cartpole_baseline.rewards
rewards_acrobot_reinforce = acrobot_reinforce.rewards
rewards_acrobot_baseline = acrobot_baseline.rewards


mean_rcr = np.mean(rewards_cartpole_reinforce, axis=0)
mean_rcb = np.mean(rewards_cartpole_baseline, axis=0)
mean_rar = np.mean(rewards_acrobot_reinforce, axis=0)
mean_rab = np.mean(rewards_acrobot_baseline, axis=0)
mean_mat = [mean_rcr, mean_rcb, mean_rar, mean_rab]

std_rcr = np.std(rewards_cartpole_reinforce, axis=0)
std_rcb = np.std(rewards_cartpole_baseline, axis=0)
std_rar = np.std(rewards_acrobot_reinforce, axis=0)
std_rab = np.std(rewards_acrobot_baseline, axis=0)
std_mat = [std_rcr, std_rcb, std_rar, std_rab]

In [None]:
names = ['cartpole_reinforce', 'cartpole_baseline',
         'acrobot_reinforce', 'acrobot_baseline']

fig, ax = plt.subplots(2, 2, figsize=(12, 12))
for i in range(2):
    for j in range(2):
        mean = mean_mat[i*2+j]
        std = std_mat[i*2+j]
        x = range(len(mean))
        ax[i][j].plot(x, mean, color='blue', label='Mean')
        ax[i][j].plot(x, smooth_rewards(mean),
                      color='orange', label='smoothed')
        ax[i][j].fill_between(x, mean - std, mean + std, color='blue',
                              alpha=0.3, label='Mean ± Std')
        ax[i][j].set_xlabel('Steps')
        ax[i][j].set_ylabel('Rewards')
        ax[i][j].set_title(names[i*2+j])
        ax[i][j].legend()
        ax[i][j].grid(True)
plt.show()

NameError: name 'plt' is not defined