In [7]:
import gym
from inari.cdqn import CDQN
import jax
from flax import linen as nn
import jax.numpy as jnp
from tqdm import tqdm
import numpy as np

In [8]:
class CDQNNetwork(nn.Module):
    """Q-network for CDQN mapping one (state, action) pair to a Q-value estimate.

    Both inputs are cast to float32, flattened, and concatenated into a single
    1-D feature vector. That vector goes through ``num_layers`` Dense+ReLU
    layers of width ``hidden_units`` and then a linear Dense layer with one
    output unit.

    NOTE(review): ``reshape(-1)`` folds any leading batch axis into the feature
    axis, so this module looks like it expects a single, unbatched pair;
    presumably batching (if any) happens in ``inari.cdqn.CDQN`` — confirm.
    """

    num_layers: int = 2  # number of hidden Dense+ReLU layers
    hidden_units: int = 16  # width of each hidden Dense layer

    @nn.compact
    def __call__(self, state: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Estimate Q(state, action).

        Args:
            state: observation array; any shape, flattened before use.
            action: action array; any shape, flattened before use.

        Returns:
            Array of shape ``(1,)`` holding the Q-value estimate.
        """
        kernel_initializer = jax.nn.initializers.glorot_uniform()

        # Cast *then* flatten in a single expression so the float32 conversion
        # is actually kept (assigning the cast and then reshaping the original
        # `state` would throw the cast away).
        s = jnp.asarray(state, dtype=jnp.float32).reshape(-1)
        a = jnp.asarray(action, dtype=jnp.float32).reshape(-1)
        x = jnp.concatenate((s, a))

        for _ in range(self.num_layers):
            x = nn.Dense(features=self.hidden_units,
                         kernel_init=kernel_initializer)(x)
            x = nn.relu(x)

        # Linear head: a single unbiased-by-activation Q-value output.
        return nn.Dense(features=1, kernel_init=kernel_initializer)(x)

In [9]:
# Experiment configuration: seed all randomness, build the environment and agent.
seed = 0
# Output directory handed to CDQN; presumably where it writes logs/checkpoints
# (it exposes a `summary_writer`) — confirm against inari.cdqn.CDQN.
BASE_DIR = './tests/'

env = gym.make('MountainCarContinuous-v0')
# `seed` must reach the environment too, not only the JAX key, or episodes are
# not reproducible. `env.seed` is the pre-0.26 gym API, matching the
# `obs = env.reset()` / 4-tuple `env.step` usage in this notebook.
env.seed(seed)
env.action_space.seed(seed)

obs_shape = env.observation_space.shape
act_shape = env.action_space.shape
rng = jax.random.PRNGKey(seed)
cdqn = CDQN(CDQNNetwork, obs_shape, act_shape, base_dir=BASE_DIR)



In [10]:
# Online training: per environment step, select an action, step the env, store
# the transition in the agent's replay and run one training step.
NUM_EPOCHS = 10
STEPS_PER_EPOCH = 30_000

obs = env.reset()
# The env is only reset when an episode ends, so an episode can span an epoch
# boundary; keep the running return outside the epoch loop so it is not
# truncated when a new epoch starts.
cum_ep_reward = 0.0
for epoch in range(NUM_EPOCHS):
    epoch_ep_rewards = []
    for step in tqdm(range(STEPS_PER_EPOCH), desc=f"epoch {epoch}"):
        rng, local_rng = jax.random.split(rng)
        act, est_q = cdqn.select_action(obs, local_rng)
        next_obs, rew, done, info = env.step(act)
        cum_ep_reward += rew
        # Store the observation the action was selected *from* (pre-step),
        # together with the action, reward and terminal flag it produced.
        # NOTE(review): assumes store_transition follows the Dopamine-style
        # convention (s_t, a_t, r_t, terminal_t) with s_{t+1} taken from the
        # next stored entry — confirm against inari.cdqn.CDQN.
        cdqn.store_transition(obs, np.array(act), rew, done)
        cdqn.train_step()
        if done:
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(str(cum_ep_reward))
            epoch_ep_rewards.append(cum_ep_reward)
            cum_ep_reward = 0.0
            next_obs = env.reset()
        obs = next_obs
    # Guard against np.mean([]) (nan + RuntimeWarning) when no episode ended.
    if epoch_ep_rewards:
        print("Mean episode reward:", np.mean(epoch_ep_rewards))
    else:
        print("Mean episode reward: no episode finished this epoch")

  3%|▎         | 1000/30000 [02:49<1:15:43,  6.38it/s]

-96.8799387493396


  5%|▍         | 1480/30000 [04:12<1:21:12,  5.85it/s]


KeyboardInterrupt: 