##Installation

In [None]:
#@title Installations { form-width: "30%" }

# Fixing the haiku problem
!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Standard installs
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[jax]
!pip install dm-acme[tf]
!pip install dm-acme[envs]
!pip install dm-env
!pip install dm-haiku
!pip install dm-tree
!pip install chex
!sudo apt-get install -y xvfb ffmpeg
!pip install imageio
!pip install gym
!pip install gym[classic_control]

!apt-get install x11-utils
!pip install pyglet

!pip install gym pyvirtualdisplay

from IPython.display import clear_output
clear_output()

Collecting pip
  Downloading pip-22.0.4-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 5.4 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.0.4
Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting jaxlib==0.3.2+cuda11.cudnn82
  Downloading https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.2%2Bcuda11.cudnn82-cp37-none-manylinux2010_x86_64.whl (155.4 MB)
[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m107.8/155.4 MB[0m [31m112.2 MB/s[0m eta [36m0:00:01[0m

In [2]:
#@title Imports  { form-width: "30%" }

%matplotlib inline
import IPython
from IPython.display import HTML
from IPython import display as ipythondisplay

import acme
from acme import datasets
from acme import types
from acme import specs
from acme.wrappers import gym_wrapper
import base64
from base64 import b64encode
import chex
import collections
from collections import namedtuple
import dm_env
import enum
import functools
import gym
import haiku as hk
import imageio
import io
import itertools
import jax
from jax import tree_util
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import multiprocessing as mp
import multiprocessing.connection
import numpy as np
import pandas as pd
import random
import reverb
import rlax
import time
import tree
from typing import *
import warnings
import pyglet
pyglet.options['search_local_libs'] = False
pyglet.options['shadow_window']=False
from pyglet.window import xlib
xlib._have_utf8 = False

from pyvirtualdisplay import Display
display = Display(visible=False, size=(1400, 900))
display.start()
 
np.set_printoptions(precision=3, suppress=1)

%matplotlib inline

##Environment

In [3]:
#@title PendulumEnv wrapper (have a look if interested) { form-width: "30%" }
class PendulumEnv(dm_env.Environment):
  def __init__(self, for_evaluation: bool) -> None:
    self._env = gym.make('Pendulum-v0')
    self._for_evaluation = for_evaluation
    if self._for_evaluation:
      self.screens = []

  def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    new_obs, reward, done, _ = self._env.step(action)
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    if done:
      return dm_env.termination(reward, new_obs)
    return dm_env.transition(reward, new_obs)

  def reset(self) -> dm_env.TimeStep:
    obs = self._env.reset()
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    return dm_env.restart(obs)

  #The observation is a `ndarray` with shape `(3,)` representing the x-y coordinates of the pendulum's free end and its angular velocity.
  def observation_spec(self) -> specs.BoundedArray:
    return specs.BoundedArray(shape=(3,), minimum=-8., maximum=8., dtype=np.float32)

  #The action is a `ndarray` with shape `(1,)` representing the torque applied to free end of the pendulum.
  def action_spec(self) -> specs.BoundedArray:
    return specs.BoundedArray(shape=(1,), minimum=-2., maximum=2., dtype=np.float32)

  def close(self) -> None:
    self._env.close()

In [4]:
#@title Simple interaction loop with Random Agent { form-width: "30%" }

#@title **[Implement]** RandomAgent { form-width: "30%" }
import tree

import abc

# Encapsulate a trajectory. Temporally, a trajectory unrolls as
# o_0, a_0, r_0, d_0, ..., o_{T-1}, a_{T-1}, r_{T-1}, d_{T-1}.
@chex.dataclass
class Trajectory:
  observations: types.NestedArray  # [T, B, ...]
  actions: types.NestedArray  # [T, B, ...]
  rewards: chex.ArrayNumpy  # [T, B]
  dones: chex.ArrayNumpy  # [T, B]
  discounts: chex.ArrayNumpy # [T, B]

class RandomAgent():
  def __init__(self, environment_spec: specs.EnvironmentSpec) -> None:
    self._action_spec = environment_spec.actions

  def batched_actor_step(self, observation: types.NestedArray) -> types.NestedArray:
    batch_size = tree.flatten(observation)[0].shape[0]
    return np.random.randn(batch_size, *self._action_spec.shape)
    
  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    # Skip it
    return dict()

#@title **[Read and understand]** Dummy interaction loop { form-width: "30%" }
def simple_interaction_loop(agent: RandomAgent, environment: dm_env.Environment, max_num_steps: int = 5000) -> None:
  ts = environment.reset()
  for _ in range(max_num_steps):
    if ts.last():
      break

    batched_observation = tree.map_structure(lambda x: x[None], ts.observation)
    action = agent.batched_actor_step(batched_observation)[0]  # batch size = 1
    ts = environment.step(action)

pendulum_environment = PendulumEnv(for_evaluation=True)  #### WARNING, I changed the env
pendulum_environment_spec = acme.make_environment_spec(pendulum_environment)
pendulum_random_agent = RandomAgent(pendulum_environment_spec)

simple_interaction_loop(pendulum_random_agent, pendulum_environment, 5000)

#@title Video display facility { form-width: "30%" }
def display_video(frames, filename='temp.mp4', frame_repeat=1):
  """Save and display video."""
  # Write video
  with imageio.get_writer(filename, fps=60) as video:
    for frame in frames:
      for _ in range(frame_repeat):
        video.append_data(frame)
  # Read video and display the video
  video = open(filename, 'rb').read()
  b64_video = base64.b64encode(video)
  video_tag = ('<video  width="320" height="240" controls alt="test" '
               'src="data:video/mp4;base64,{0}">').format(b64_video.decode())
  return IPython.display.HTML(video_tag)

#@title Display a video of your random agent { form-width: "30%" }
display_video(np.stack(pendulum_environment.screens, axis=0))



In [5]:
#@title **[Solution]** Uniform Replay Buffer { form-width: "30%" }


Transition = collections.namedtuple("Transition", 
                                    field_names=["obs_tm1", "action_tm1", "reward_t", "discount_t", "obs_t", "done"])

class ReplayBuffer:
    """Fixed-size buffer to store transition tuples."""

    def __init__(self, buffer_capacity: int):
        """Initialize a ReplayBuffer object.
        Args:
            batch_size (int): size of each training batch
        """
        self._memory = list()
        self._maxlen = buffer_capacity

    def add(self, obs_tm1, action_tm1, reward_t, discount_t, obs_t, done):
        """Add a new transition to memory."""
        if len(self._memory) >= self._maxlen: 
          self._memory.pop(0)  # remove first elem (oldest)

        transition = Transition(
            obs_tm1=obs_tm1,
            action_tm1=action_tm1,
            reward_t=reward_t,
            discount_t=discount_t,
            obs_t=obs_t,
            done=done)
        
        # convert every data into jnp array
        transition = jax.tree_map(jnp.array, transition)

        self._memory.append(transition)

    def sample(self):
        """Randomly sample a transition from memory."""
        assert self._memory, 'replay buffer is unfilled'
        transition_idx = np.random.randint(0, len(self._memory))
        transition = self._memory.pop(transition_idx)
        
        return transition


In [6]:
#@title **[Solution]** Uniform Replay Buffer with Batch{ form-width: "30%" }

class BatchedReplayBuffer(ReplayBuffer):

      def sample_batch(self, batch_size):
        """Randomly sample a batch of experiences from memory."""
        assert len(self._memory) > batch_size, 'Insuficient number of transitions in replay buffer ' + str(len(self._memory))
        all_transitions = [self.sample() for _ in range(batch_size)] 

        stacked_transitions = []
        for i, _ in enumerate(all_transitions[0]):
          arrays = [t[i] for t in all_transitions]
          arrays = jnp.stack(arrays, axis=0)
          stacked_transitions.append(arrays)

        return Transition(*stacked_transitions)

In [7]:
num_episodes = 5

pendulum_environment = PendulumEnv(for_evaluation=True)  #### WARNING, I changed the env
pendulum_environment_spec = acme.make_environment_spec(pendulum_environment)
pendulum_random_agent = RandomAgent(pendulum_environment_spec)

replay_buffer = BatchedReplayBuffer(1000)

random_action = 0

for episode in range(num_episodes):

  # Reset any counts and start the environment.
  timestep = pendulum_environment.reset()

  while not timestep.last():
    
    batched_observation = tree.map_structure(lambda x: x[None], timestep.observation)
    action = pendulum_random_agent.batched_actor_step(batched_observation)[0]  # batch size = 1
    random_action =  action
    timestep_tm1 = timestep
    timestep = pendulum_environment.step(action)

    replay_buffer.add(obs_tm1=timestep_tm1.observation, 
                      action_tm1=action,
                      reward_t=timestep.reward,
                      discount_t=timestep.discount,
                      obs_t=timestep.observation,
                      done=timestep.last())




In [8]:
class ValueNetwork(hk.Module):
  def __init__(self, output_sizes: Sequence[int], name: Optional[str] = None) -> None:
    super().__init__(name=name)
    self._output_sizes = output_sizes

  def __call__(self, observations: chex.Array) -> chex.Array:
    h = observations

    for i, o in enumerate(self._output_sizes):
      h = hk.Linear(o)(h)
      h = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
      h = jax.nn.relu(h)
    return hk.Linear(1)(h)[..., 0]

class SoftQNetwork(hk.Module):
  def __init__(self, output_sizes: Sequence[int], name: Optional[str] = None) -> None:
    super().__init__(name=name)
    self._output_sizes = output_sizes

  def __call__(self, observations: chex.Array, actions: chex.Array) -> chex.Array:
    h = jnp.concatenate([observations, actions], axis=1)

    for i, o in enumerate(self._output_sizes):
      h = hk.Linear(o)(h)
      h = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
      h = jax.nn.relu(h)
    return hk.Linear(1)(h)[..., 0]

class PolicyNetwork(hk.Module):
  def __init__(self, output_sizes: Sequence[int], action_spec: specs.BoundedArray, name: Optional[str] = None) -> None:
    super().__init__(name=name)
    self._output_sizes = output_sizes
    self._action_spec = action_spec

  def __call__(self, x: chex.Array, ) -> Tuple[chex.Array, chex.Array]:
    action_shape = self._action_spec.shape
    action_dims = np.prod(action_shape)
    h = x
    for i, o in enumerate(self._output_sizes):
      h = hk.Linear(o)(h)
      h = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
      h = jax.nn.relu(h)
    h = hk.Linear(2 * action_dims)(h)
    mu, pre_sigma = jnp.split(h, 2, axis=-1)
    sigma = jax.nn.softplus(pre_sigma)
    return hk.Reshape(action_shape)(.1 * mu), hk.Reshape(action_shape)(.1 * sigma)

In [10]:
def value_net(observations: types.NestedArray):
  return ValueNetwork([256, 256])(observations)

def soft_q_net(observations: types.NestedArray, actions: types.NestedArray):
  return SoftQNetwork([256, 256])(observations, actions)

def policy_net(observations: types.NestedArray):
  return PolicyNetwork([256, 256], pendulum_environment.action_spec())(observations)

value_net_init, value_net_apply = hk.without_apply_rng(hk.transform(value_net))
policy_net_init, policy_net_apply = hk.without_apply_rng(hk.transform(policy_net))
soft_q_net_init, soft_q_net_apply = hk.without_apply_rng(hk.transform(soft_q_net))

init_rng = jax.random.PRNGKey(0)
init_timestep = timestep = pendulum_environment.reset()

value_online_params = value_net_init(init_rng, init_timestep.observation[None])  # None -> to add a batch dimension to the observation
value_target_params = value_online_params

policy_params = policy_net_init(init_rng, init_timestep.observation[None])  # None -> to add a batch dimension to the observation

random_action = pendulum_random_agent.batched_actor_step(init_timestep.observation[None])
soft_q1_params = soft_q_net_init(init_rng, init_timestep.observation[None], random_action)  # None -> to add a batch dimension to the observation
soft_q2_params = soft_q1_params

value_net_apply_jitted = jax.jit(value_net_apply)
policy_net_apply_jitted = jax.jit(policy_net_apply)
soft_q_net_apply_jitted = jax.jit(soft_q_net_apply)

batched_observation = tree.map_structure(lambda x: x[None], timestep.observation)
action = pendulum_random_agent.batched_actor_step(batched_observation)[0]  # batch size = 1
timestep = pendulum_environment.step(action)

samples = replay_buffer.sample_batch(32)
#print(samples.action_tm1.shape)
#print(samples.obs_tm1.shape)
#print(jnp.concatenate([samples.obs_tm1, samples.action_tm1], axis=1).shape)
predicted_q_value = soft_q_net_apply_jitted(soft_q1_params, samples.obs_tm1, samples.action_tm1)
predicted_value = value_net_apply_jitted(value_target_params, samples.obs_tm1)
mu, sigma = policy_net_apply_jitted(policy_params, samples.obs_tm1)
#qa_tm1 = jax.vmap(lambda q, a: q[a])(predicted_q_value, samples.action_tm1)
print(predicted_q_value)
print(predicted_value)
print(mu, sigma)

  lax_internal._check_user_dtype_supported(dtype, "zeros")


[ 0.47   0.187  0.287  0.706  0.593  0.132  0.318  0.272  0.357  0.2
  1.072  0.075  0.062  0.299  0.099  0.318  0.066 -0.401  0.6    1.053
  0.145  0.121  0.272  0.055  0.208  0.277  0.554  0.349  0.282  0.078
  0.258 -0.256]
[ 0.299 -0.09   0.842  0.254  0.61  -0.182 -0.025 -0.155  0.01  -0.143
  0.452  0.989  0.968 -0.176  0.968  0.813  0.883 -0.133 -0.04   0.546
  1.059  0.966  0.901  0.882  0.046 -0.161  0.392  0.636  0.426  1.002
  0.07   0.181]
[[ 0.047]
 [-0.004]
 [-0.024]
 [ 0.037]
 [ 0.057]
 [-0.004]
 [ 0.014]
 [-0.005]
 [ 0.006]
 [-0.003]
 [-0.03 ]
 [-0.02 ]
 [-0.021]
 [-0.005]
 [-0.02 ]
 [-0.023]
 [-0.013]
 [ 0.004]
 [ 0.054]
 [ 0.008]
 [-0.025]
 [-0.012]
 [-0.02 ]
 [-0.03 ]
 [ 0.016]
 [-0.005]
 [-0.019]
 [-0.037]
 [-0.042]
 [-0.021]
 [ 0.024]
 [-0.014]] [[0.13 ]
 [0.13 ]
 [0.09 ]
 [0.137]
 [0.148]
 [0.118]
 [0.103]
 [0.121]
 [0.128]
 [0.115]
 [0.159]
 [0.103]
 [0.102]
 [0.121]
 [0.106]
 [0.094]
 [0.115]
 [0.106]
 [0.092]
 [0.108]
 [0.096]
 [0.133]
 [0.102]
 [0.096]
 [0.113

In [11]:
rng = jax.random.PRNGKey(0)
#rlax.gaussian_diagonal().sample(rng, mu, sigma) + samples.action_tm1

In [12]:
def q_loss(q1_params, q2_params, target_value_params, obs_tm1, action_tm1, reward_t, discount_t, obs_t, done):
  target_value    = value_net_apply_jitted(target_value_params, obs_t)
  target_q_value  = jax.lax.stop_gradient(target_value)
  target_q_value = (1. - done[..., None]) * target_value
  target_q_value  = reward_t + discount_t * target_value

  predicted_q1_value = soft_q_net_apply_jitted(q1_params, obs_tm1, action_tm1)
  predicted_q2_value = soft_q_net_apply_jitted(q2_params, obs_tm1, action_tm1)

  td_error1 = target_q_value - predicted_q1_value
  td_error2 = target_q_value - predicted_q2_value

  q_loss1 = 0.5 * jnp.square(td_error1)
  q_loss2 = 0.5 * jnp.square(td_error2)
  
  q_loss1 = jnp.mean(q_loss1)
  q_loss2 = jnp.mean(q_loss2)

  return q_loss1 + q_loss2

rng = jax.random.PRNGKey(0)

def policy_loss(policy_params, q_params, obs_tm1, action_tm1, reward_t, discount_t, obs_t, done):
  mu, sigma = policy_net_apply_jitted(policy_params, obs_tm1)
  new_actions = rlax.gaussian_diagonal().sample(rng, mu, sigma)
  predicted_new_q_value = soft_q_net_apply_jitted(q_params, obs_tm1, new_actions)
  action_log_probs = rlax.gaussian_diagonal().logprob(action_tm1, mu, sigma)

  policy_loss = jnp.mean(action_log_probs - predicted_new_q_value)

  return policy_loss

def v_loss(online_value_params, policy_params, q1_params, q2_params, obs_tm1, action_tm1, reward_t, discount_t, obs_t, done):
  mu, sigma = policy_net_apply_jitted(policy_params, obs_tm1)
  action_log_probs = rlax.gaussian_diagonal().logprob(action_tm1, mu, sigma)
  new_actions = rlax.gaussian_diagonal().sample(rng, mu, sigma)

  predicted_new_q1_value = soft_q_net_apply_jitted(q1_params, obs_tm1, new_actions)
  predicted_new_q2_value = soft_q_net_apply_jitted(q2_params, obs_tm1, new_actions)
  predicted_new_q_value = jax.lax.min(predicted_new_q1_value, predicted_new_q2_value)
  
  
  target_value_func = predicted_new_q_value - action_log_probs
  target_value_func = jax.lax.stop_gradient(target_value_func)

  predicted_value = value_net_apply_jitted(online_value_params, obs_t)

  error = predicted_value - target_value_func

  # Compute the L2 error in expectation
  loss = 0.5 * jnp.square(error)
  loss = jnp.mean(loss)

  return loss


In [13]:
def pure_function(y: chex.Array, x: chex.Array) -> chex.Array:
  return x**2 + y**3

grad_pure = jax.value_and_grad(pure_function, argnums=[0,1])
x = 3.
y = 2.
print(f'Value at point x={x}, f(x)={pure_function(y,x)}, grad_f(x)={grad_pure(y,x)}')
#Value at point x=3.0, f(x)=17.0, grad_f(x)=6.0
#Value at point x=3.0, f(x)=17.0, grad_f(x)=12.0
#Value at point x=3.0, f(x)=17.0, grad_f(x)=(DeviceArray(12., dtype=float32, weak_type=True), DeviceArray(6., dtype=float32, weak_type=True))

Value at point x=3.0, f(x)=17.0, grad_f(x)=(DeviceArray(17., dtype=float32, weak_type=True), (DeviceArray(12., dtype=float32, weak_type=True), DeviceArray(6., dtype=float32, weak_type=True)))


In [14]:
optimizer = optax.adam(3e-4)
q1_grad_fn = jax.value_and_grad(q_loss, argnums=0)
q2_grad_fn = jax.value_and_grad(q_loss, argnums=1)
v_grad_fn = jax.value_and_grad(v_loss)
policy_grad_fn = jax.value_and_grad(policy_loss)

def update_fn(value_target_params, value_online_params, soft_q1_params, soft_q2_params, policy_params, q1_opt_state, q2_opt_state, v_opt_state, p_opt_state, samples):
  loss_q1, grad_q1 = q1_grad_fn(soft_q1_params, soft_q2_params, value_target_params, *samples)
  loss_q2, grad_q2 = q2_grad_fn(soft_q1_params, soft_q2_params, value_target_params, *samples)
  updates, new_q1_opt_state = optimizer_q.update(grad_q1, q1_opt_state)
  new_soft_q1_params = optax.apply_updates(soft_q1_params, updates)
  updates, new_q2_opt_state = optimizer_q.update(grad_q2, q2_opt_state)
  new_soft_q2_params = optax.apply_updates(soft_q2_params, updates)

  # TODO fix here add second q-network
  loss_pi, grad_pi = policy_grad_fn(policy_params, soft_q1_params, *samples)
  updates, new_p_opt_state = optimizer_q.update(grad_pi, p_opt_state)
  new_policy_params = optax.apply_updates(policy_params, updates)

  loss_v, grad_v = v_grad_fn(value_online_params, policy_params, soft_q1_params, soft_q2_params, *samples)
  updates, new_v_opt_state = optimizer_q.update(grad_v, v_opt_state)
  new_value_online_params = optax.apply_updates(value_online_params, updates)

  ### Target network update with polyak averaging
  new_value_target_params = jax.tree_multimap(lambda x, y: x + 0.9 * (y - x),
                                    value_target_params, new_value_online_params)
  
  return new_value_target_params, new_value_online_params, new_soft_q1_params, new_soft_q2_params, new_policy_params, new_v_opt_state, new_q1_opt_state, new_q2_opt_state, new_p_opt_state

def update_fn_old(value_target_params, value_online_params, soft_q_params, policy_params, q_opt_state, v_opt_state, p_opt_state, samples):
  # Compute gradient
  #gradients = grad_fn(online_params, target_params, *samples)
  q_gradients = q_grad_fn(value_target_params, soft_q_params, *samples)
  v_gradients = v_grad_fn(value_online_params, policy_params, soft_q_params, *samples)
  p_gradients = policy_grad_fn(soft_q_params, *samples)

  # Apply gradients
  q_updates, q_new_opt_state = optimizer.update(q_gradients, q_opt_state)
  v_updates, v_new_opt_state = optimizer.update(v_gradients, v_opt_state)
  p_updates, p_new_opt_state = optimizer.update(p_gradients, p_opt_state)
  
  new_soft_q_params = optax.apply_updates(soft_q_params, q_updates)
  new_soft_q_params = optax.apply_updates(soft_q_params, q_updates)
  new_soft_q_params = optax.apply_updates(soft_q_params, q_updates)
  new_soft_q_params = optax.apply_updates(soft_q_params, q_updates)


In [15]:
optimizer_q = optax.adam(3e-4)
optimizer_v = optax.adam(3e-4)
optimizer_pi = optax.adam(3e-4)

q_grad_fn = jax.grad(q_loss)
v_grad_fn = jax.grad(v_loss)
q1_grad_fn = jax.value_and_grad(q_loss, argnums=0)
q2_grad_fn = jax.value_and_grad(q_loss, argnums=1)

q1_grad = q1_grad_fn(soft_q1_params, soft_q2_params, value_target_params, *samples)
q2_grad = q2_grad_fn(soft_q1_params, soft_q2_params, value_target_params, *samples)
# TODO are both gradients the same? check values during iteration

In [None]:
# Training Options
num_training_loop = 500
num_acting_steps = 64
replay_buffer_size= 1000
epsilon = 0.1
batch_size = 32
target_ema= 0.99
gamma = .9

# Reset Network params
init_timestep = pendulum_environment.reset()
#online_params = network_init(rng, init_timestep.observation[None]) 
#target_params = online_params

# Create Optimizer
optimizer = optax.adam(3e-4)
v_opt_state = optimizer.init(value_online_params)
q1_opt_state = optimizer.init(soft_q1_params)
q2_opt_state = optimizer.init(soft_q2_params)
p_opt_state = optimizer.init(policy_params)
#opt_state = optimizer.init(online_params)

# Define the update function
q1_grad_fn = jax.value_and_grad(q_loss, argnums=0)
q2_grad_fn = jax.value_and_grad(q_loss, argnums=1)
v_grad_fn = jax.value_and_grad(v_loss)
policy_grad_fn = jax.value_and_grad(policy_loss)
rng = jax.random.PRNGKey(0)

update_fn_jitted = jax.jit(update_fn)

# Create Replay buffer
replay_buffer = BatchedReplayBuffer(replay_buffer_size)


########
### Vanilla DQN
#######

rewards = collections.deque(maxlen=1000)
q_values = collections.deque(maxlen=1000)
actions = collections.deque(maxlen=1000)

timestep = pendulum_environment.reset()
action = random_action  # dummy action


for t in range(num_training_loop):

  ### Acting

  # Act in the environment and store it in the replay buffer
  for _ in range(num_acting_steps):

    if timestep.last():
      rewards.append(timestep.reward)
      timestep = pendulum_environment.reset()
      action = action = random_action  # dummy action
      
    # Sample action according policy
    mu, sigma = policy_net_apply_jitted(policy_params, timestep.observation[None][0])
    new_actions = rlax.gaussian_diagonal().sample(rng, mu, sigma)

    # Store values
    #q_values.append(q_value)
    #actions.append(np.eye(environment.action_spec().num_values)[action])

    # Step into environment
    timestep_tm1 = timestep
    timestep = pendulum_environment.step(new_actions)

    # Store in replay buffer
    replay_buffer.add(obs_tm1=timestep_tm1.observation, 
                      action_tm1=new_actions,
                      reward_t=timestep.reward,
                      discount_t=gamma * timestep.discount,
                      obs_t=timestep.observation,
                      done=timestep.last())
    
  ### Learning

  # updates
  samples = replay_buffer.sample_batch(batch_size)
  value_target_params, value_online_params, soft_q1_params, soft_q2_params, policy_params, \
  v_opt_state, q1_opt_state, q2_opt_state, p_opt_state = update_fn(
      value_target_params, value_online_params, soft_q1_params, soft_q2_params, policy_params, q1_opt_state, q2_opt_state, v_opt_state, p_opt_state, samples)


  if t % 100 == 0:
    print(t, "\t reward", np.mean(rewards), "\t actions")



  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


0 	 reward nan 	 actions
50 	 reward -6.7860456 	 actions
100 	 reward -6.5060735 	 actions
150 	 reward -6.9153953 	 actions
