## Installation

In [2]:
#@title Installations { form-width: "30%" }

# Fixing the haiku problem
!pip install --upgrade pip
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Standard installs
!pip install dm-acme
!pip install dm-acme[reverb]
!pip install dm-acme[jax]
!pip install dm-acme[tf]
!pip install dm-acme[envs]
!pip install dm-env
!pip install dm-haiku
!pip install dm-tree
!pip install chex
!sudo apt-get install -y xvfb ffmpeg
!pip install imageio
!pip install gym
!pip install gym[classic_control]

!apt-get install x11-utils
!pip install pyglet

!pip install gym pyvirtualdisplay

from IPython.display import clear_output
clear_output()

In [3]:
#@title Imports  { form-width: "30%" }

%matplotlib inline
import IPython
from IPython.display import HTML
from IPython import display as ipythondisplay

import acme
from acme import datasets
from acme import types
from acme import specs
from acme.wrappers import gym_wrapper
import base64
from base64 import b64encode
import chex
import collections
from collections import namedtuple
import dm_env
import enum
import functools
import gym
import haiku as hk
import imageio
import io
import itertools
import jax
from jax import tree_util
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import multiprocessing as mp
import multiprocessing.connection
import numpy as np
import pandas as pd
import random
import reverb
import rlax
import time
import tree
from typing import *
import warnings
import pyglet
# NOTE(review): the pyglet settings below appear to be workarounds for rendering
# without a physical display (no local-library search, no hidden shadow window,
# and a private xlib flag that disables its UTF-8 path) -- confirm they are still
# needed with the installed pyglet version.
pyglet.options['search_local_libs'] = False
pyglet.options['shadow_window']=False
from pyglet.window import xlib
xlib._have_utf8 = False

# Start an invisible 1400x900 virtual X display, presumably so that the
# env.render(mode='rgb_array') calls used later have a display to draw on in a
# headless runtime.
from pyvirtualdisplay import Display
display = Display(visible=False, size=(1400, 900))
display.start()
 
# Print numpy arrays with 3 decimals and without scientific notation for tiny values.
np.set_printoptions(precision=3, suppress=1)

# NOTE: repeats the `%matplotlib inline` at the top of this cell; harmless.
%matplotlib inline

## Environment

In [4]:
#@title PendulumEnv wrapper (have a look if interested) { form-width: "30%" }
class PendulumEnv(dm_env.Environment):
  """dm_env.Environment adapter around gym's `Pendulum-v0`.

  Args:
    for_evaluation: if True, a frame rendered with `render(mode='rgb_array')`
      is appended to `self.screens` after every reset and step, so an episode
      can later be played back as a video. `self.screens` only exists when
      this flag is True.
  """

  def __init__(self, for_evaluation: bool) -> None:
    self._env = gym.make('Pendulum-v0')
    self._for_evaluation = for_evaluation
    if self._for_evaluation:
      self.screens = []

  def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Applies `action` and converts gym's (obs, reward, done, info) result.

    An episode that gym's `TimeLimit` wrapper cut off (it sets
    `info['TimeLimit.truncated']`) is returned as `dm_env.truncation`. Its
    discount stays 1, so value targets may still bootstrap from the final
    observation. Any other `done` is a true `dm_env.termination` (discount 0).
    """
    new_obs, reward, done, info = self._env.step(action)
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    if done:
      if isinstance(info, dict) and info.get('TimeLimit.truncated', False):
        return dm_env.truncation(reward, new_obs)
      return dm_env.termination(reward, new_obs)
    return dm_env.transition(reward, new_obs)

  def reset(self) -> dm_env.TimeStep:
    """Starts a new episode, recording its first frame when evaluating."""
    obs = self._env.reset()
    if self._for_evaluation:
      self.screens.append(self._env.render(mode='rgb_array'))
    return dm_env.restart(obs)

  def observation_spec(self) -> specs.BoundedArray:
    """Spec of the `(3,)` observation.

    The observation holds the x-y coordinates of the pendulum's free end
    (cos/sin of its angle, each in [-1, 1]) and its angular velocity (in
    [-8, 8]). These per-dimension bounds match gym's Pendulum observation
    space.
    """
    return specs.BoundedArray(
        shape=(3,),
        minimum=np.array([-1., -1., -8.], dtype=np.float32),
        maximum=np.array([1., 1., 8.], dtype=np.float32),
        dtype=np.float32)

  def action_spec(self) -> specs.BoundedArray:
    """Spec of the `(1,)` action: the torque applied to the pendulum's free end."""
    return specs.BoundedArray(shape=(1,), minimum=-2., maximum=2., dtype=np.float32)

  def close(self) -> None:
    """Closes the wrapped gym environment."""
    self._env.close()

In [5]:
#@title Simple interaction loop with Random Agent { form-width: "30%" }

#@title **[Implement]** RandomAgent { form-width: "30%" }
import tree

import abc

# Encapsulate a trajectory. Temporally, a trajectory unrolls as
# o_0, a_0, r_0, d_0, ..., o_{T-1}, a_{T-1}, r_{T-1}, d_{T-1}.
@chex.dataclass
class Trajectory:
  """Time-major batch of agent experience.

  Every field carries leading [T, B] (time, batch) dimensions, as annotated
  on each field below.
  """
  observations: types.NestedArray  # [T, B, ...] observations, possibly nested
  actions: types.NestedArray  # [T, B, ...] actions, possibly nested
  rewards: chex.ArrayNumpy  # [T, B]
  dones: chex.ArrayNumpy  # [T, B] -- TODO confirm: presumably end-of-episode flags
  discounts: chex.ArrayNumpy # [T, B] -- presumably per-step discounts

class RandomAgent():
  """Agent that ignores observations and acts at random.

  When the action spec has finite `minimum`/`maximum` bounds (as the Pendulum
  torque spec does), actions are drawn uniformly from that box, so no sampled
  action leaves the valid range. Otherwise it falls back to a standard-normal
  sample. Actions are cast to the spec's dtype.
  """

  def __init__(self, environment_spec: specs.EnvironmentSpec) -> None:
    self._action_spec = environment_spec.actions

  def batched_actor_step(self, observation: types.NestedArray) -> types.NestedArray:
    """Returns one random action per element of the observation batch.

    Args:
      observation: batched (possibly nested) observation. The batch size is
        read from the leading dimension of its first leaf.

    Returns:
      Array of shape `[batch_size, *action_spec.shape]` in the spec's dtype.
    """
    batch_size = tree.flatten(observation)[0].shape[0]
    sample_shape = (batch_size, *self._action_spec.shape)
    minimum = getattr(self._action_spec, 'minimum', None)
    maximum = getattr(self._action_spec, 'maximum', None)
    bounded = (minimum is not None and maximum is not None
               and np.all(np.isfinite(minimum)) and np.all(np.isfinite(maximum)))
    if bounded:
      # The per-dimension bounds broadcast over the leading batch axis.
      actions = np.random.uniform(low=minimum, high=maximum, size=sample_shape)
    else:
      actions = np.random.randn(*sample_shape)
    return np.asarray(actions, dtype=self._action_spec.dtype)

  def learner_step(self, trajectory: Trajectory) -> Mapping[str, chex.ArrayNumpy]:
    """No-op: a random agent has nothing to learn. Returns an empty metrics dict."""
    return dict()

#@title **[Read and understand]** Dummy interaction loop { form-width: "30%" }
def simple_interaction_loop(agent: RandomAgent, environment: dm_env.Environment, max_num_steps: int = 5000) -> None:
  """Plays one episode of `agent` in `environment`, without learning.

  Stops at the first final timestep or after `max_num_steps` actions,
  whichever happens first.
  """
  timestep = environment.reset()
  steps_taken = 0
  while steps_taken < max_num_steps and not timestep.last():
    # The agent works on batches: add a batch axis of size 1, then strip it again.
    single_batch = tree.map_structure(lambda leaf: leaf[None], timestep.observation)
    timestep = environment.step(agent.batched_actor_step(single_batch)[0])
    steps_taken += 1

# for_evaluation=True makes the environment render a frame on every reset/step
# into `pendulum_environment.screens`; the video cell below plays them back.
# NOTE(review): the original author flagged "I changed the env" here -- confirm what changed.
pendulum_environment = PendulumEnv(for_evaluation=True)
pendulum_environment_spec = acme.make_environment_spec(pendulum_environment)
pendulum_random_agent = RandomAgent(pendulum_environment_spec)

# Play one episode (capped at 5000 steps) with the random agent.
simple_interaction_loop(pendulum_random_agent, pendulum_environment, 5000)

#@title Video display facility { form-width: "30%" }
def display_video(frames, filename='temp.mp4', frame_repeat=1, fps=60):
  """Save frames to an MP4 file and return an inline HTML5 video player.

  Args:
    frames: iterable of frames accepted by `imageio`'s writer (e.g. the RGB
      `screens` recorded by an evaluation environment).
    filename: path of the intermediate MP4 file.
    frame_repeat: how many times each frame is written (slows playback down).
    fps: frame rate of the written video.

  Returns:
    An `IPython.display.HTML` object embedding the video as base64 data.
  """
  with imageio.get_writer(filename, fps=fps) as writer:
    for frame in frames:
      for _ in range(frame_repeat):
        writer.append_data(frame)
  # Read the encoded file back, closing the handle, and inline it as a data URI.
  with open(filename, 'rb') as video_file:
    b64_video = base64.b64encode(video_file.read()).decode()
  video_tag = ('<video  width="320" height="240" controls alt="test" '
               'src="data:video/mp4;base64,{0}"></video>').format(b64_video)
  return IPython.display.HTML(video_tag)

#@title Display a video of your random agent { form-width: "30%" }
# Play back the frames recorded during the random agent's episode. Kept as the
# cell's last expression so the returned HTML player is rendered.
display_video(np.stack(pendulum_environment.screens, axis=0))



In [6]:
#@title **[Solution]** Uniform Replay Buffer { form-width: "30%" }


# One stored environment step: the observation and action at t-1, then the
# reward, discount and observation at t, plus the `done` flag.
Transition = collections.namedtuple(
    "Transition", "obs_tm1 action_tm1 reward_t discount_t obs_t done")

class ReplayBuffer:
    """Fixed-capacity FIFO store of `Transition`s with uniform random sampling.

    Once `buffer_capacity` transitions are stored, each new one overwrites
    the oldest. Sampling never removes anything from the buffer.
    """

    def __init__(self, buffer_capacity: int):
        """Initialize a ReplayBuffer object.

        Args:
            buffer_capacity (int): maximum number of transitions kept.

        Raises:
            ValueError: if `buffer_capacity` is not positive.
        """
        if buffer_capacity <= 0:
            raise ValueError(
                f'buffer_capacity must be positive, got {buffer_capacity}')
        self._memory = list()
        self._maxlen = buffer_capacity
        # Slot the next `add` writes once full; it always holds the oldest entry.
        self._next_idx = 0

    def add(self, obs_tm1, action_tm1, reward_t, discount_t, obs_t, done):
        """Add a new transition, evicting the oldest one when at capacity."""
        transition = Transition(
            obs_tm1=obs_tm1,
            action_tm1=action_tm1,
            reward_t=reward_t,
            discount_t=discount_t,
            obs_t=obs_t,
            done=done)

        # Convert every leaf into a jnp array so sampled batches stack directly.
        transition = jax.tree_util.tree_map(jnp.array, transition)

        if len(self._memory) < self._maxlen:
            self._memory.append(transition)
        else:
            # Ring-buffer overwrite: O(1), unlike list.pop(0), which shifts every element.
            self._memory[self._next_idx] = transition
        self._next_idx = (self._next_idx + 1) % self._maxlen

    def sample(self):
        """Return one uniformly random stored transition, leaving the buffer unchanged.

        Raises:
            AssertionError: if the buffer is empty.
        """
        assert self._memory, 'replay buffer is unfilled'
        transition_idx = np.random.randint(0, len(self._memory))
        return self._memory[transition_idx]


In [7]:
num_episodes = 5

# for_evaluation=True also renders a frame into `screens` on every reset/step.
pendulum_environment = PendulumEnv(for_evaluation=True)
pendulum_environment_spec = acme.make_environment_spec(pendulum_environment)
pendulum_random_agent = RandomAgent(pendulum_environment_spec)

replay_buffer = ReplayBuffer(10)

for episode in range(num_episodes):

  # Reset any counts and start the environment.
  timestep = pendulum_environment.reset()

  while not timestep.last():
    obs_tm1 = timestep.observation
    batched_observation = tree.map_structure(lambda x: x[None], obs_tm1)
    action = pendulum_random_agent.batched_actor_step(batched_observation)[0]  # batch size = 1
    # Step before storing: reward_t, discount_t, obs_t and done must all describe
    # the outcome of taking `action` in `obs_tm1`, not the step that led to it.
    timestep = pendulum_environment.step(action)
    replay_buffer.add(obs_tm1=obs_tm1,
                      action_tm1=action,
                      reward_t=timestep.reward,
                      discount_t=timestep.discount,
                      obs_t=timestep.observation,
                      done=timestep.last())




In [8]:
#@title **[Solution]** Uniform Replay Buffer with Batch{ form-width: "30%" }

class BatchedReplayBuffer(ReplayBuffer):
    """ReplayBuffer that can also return transitions stacked into a batch."""

    def sample_batch(self, batch_size):
        """Randomly sample a batch of experiences from memory.

        Args:
            batch_size: number of transitions to draw via `self.sample()`.

        Returns:
            A `Transition` whose every leaf gains a new leading axis of size
            `batch_size`.

        Raises:
            ValueError: if `batch_size` is not positive.
            AssertionError: if fewer than `batch_size` transitions are stored.
        """
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        num_stored = len(self._memory)
        # `>=`: a buffer holding exactly `batch_size` transitions can fill a batch.
        assert num_stored >= batch_size, (
            'Insufficient number of transitions in replay buffer '
            f'({num_stored} stored, {batch_size} requested)')
        all_transitions = [self.sample() for _ in range(batch_size)]
        # Stack corresponding leaves across the sampled transitions; the result
        # keeps the namedtuple type of the sampled transitions.
        return jax.tree_util.tree_map(
            lambda *leaves: jnp.stack(leaves, axis=0), *all_transitions)

In [25]:
class ValueNetwork(hk.Module):
  """MLP producing one scalar per example.

  Each hidden layer is Linear -> LayerNorm -> ReLU, with widths taken from
  `output_sizes`. The head is a Linear(1) whose trailing unit axis is dropped.
  """

  def __init__(self, output_sizes: Sequence[int], name: Optional[str] = None) -> None:
    super().__init__(name=name)
    self._output_sizes = output_sizes

  def __call__(self, x: chex.Array) -> chex.Array:
    features = x
    for width in self._output_sizes:
      features = hk.Linear(width)(features)
      features = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(features)
      features = jax.nn.relu(features)
    scalar_out = hk.Linear(1)(features)
    return jnp.squeeze(scalar_out, axis=-1)

class SoftQNetwork(hk.Module):
  """Soft Q-function MLP producing one scalar per example.

  Hidden layers are Linear -> LayerNorm -> ReLU, followed by a Linear(1) head
  whose trailing unit axis is dropped.

  Calling `net(x)` reads only `x`, exactly as before. For a state-action value
  Q(s, a), call `net(x, action)`: `action` is then concatenated to `x` along
  the last axis. Parameters initialised without an action expect the narrower
  input, so init and apply must agree on whether `action` is passed.
  """

  def __init__(self, output_sizes: Sequence[int], name: Optional[str] = None) -> None:
    super().__init__(name=name)
    self._output_sizes = output_sizes

  def __call__(self, x: chex.Array, action: Optional[chex.Array] = None) -> chex.Array:
    # NOTE(review): assumes `action` shares x's leading (batch) dimensions.
    h = x if action is None else jnp.concatenate([x, action], axis=-1)
    for o in self._output_sizes:
      h = hk.Linear(o)(h)
      h = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
      h = jax.nn.relu(h)
    return hk.Linear(1)(h)[..., 0]

class PolicyNetwork(hk.Module):
  """MLP returning `(mu, sigma)`, each shaped like the action spec.

  The arrays are presumably the mean and standard deviation of a Gaussian
  policy. Hidden layers are Linear -> LayerNorm -> ReLU. A final Linear emits
  2 * prod(action_shape) values, which are split into `mu` and a softplus-ed
  `sigma`. Both outputs are multiplied by `output_scale` (0.1 by default).
  """

  def __init__(self, output_sizes: Sequence[int], action_spec: specs.BoundedArray,
               name: Optional[str] = None, output_scale: float = .1) -> None:
    super().__init__(name=name)
    self._output_sizes = output_sizes
    self._action_spec = action_spec
    # Multiplier applied to both heads; the default reproduces the original 0.1.
    self._output_scale = output_scale

  def __call__(self, x: chex.Array) -> Tuple[chex.Array, chex.Array]:
    action_shape = self._action_spec.shape
    action_dims = int(np.prod(action_shape))
    h = x
    for o in self._output_sizes:
      h = hk.Linear(o)(h)
      h = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
      h = jax.nn.relu(h)
    h = hk.Linear(2 * action_dims)(h)
    mu, pre_sigma = jnp.split(h, 2, axis=-1)
    # softplus maps to (0, inf); it can still underflow to 0 for very negative inputs.
    sigma = jax.nn.softplus(pre_sigma)
    return (hk.Reshape(action_shape)(self._output_scale * mu),
            hk.Reshape(action_shape)(self._output_scale * sigma))

In [28]:
def value_net(observations: types.NestedArray):
  """Haiku forward function: scalar ValueNetwork output (two 256-unit layers)."""
  network = ValueNetwork([256, 256])
  return network(observations)

def soft_q_net(observations: types.NestedArray):
  """Haiku forward function: scalar SoftQNetwork output (two 256-unit layers)."""
  network = SoftQNetwork([256, 256])
  return network(observations)

def policy_net(observations: types.NestedArray):
  """Haiku forward function: PolicyNetwork outputs shaped by the pendulum action spec.

  Reads the global `pendulum_environment` each time it is traced.
  """
  action_spec = pendulum_environment.action_spec()
  network = PolicyNetwork([256, 256], action_spec)
  return network(observations)

# Init/apply pairs; `without_apply_rng` removes the rng argument from apply.
value_net_init, value_net_apply = hk.without_apply_rng(hk.transform(value_net))
target_value_net_init, target_value_net_apply = hk.without_apply_rng(hk.transform(value_net))
policy_net_init, policy_net_apply = hk.without_apply_rng(hk.transform(policy_net))
soft_q_net_init, soft_q_net_apply = hk.without_apply_rng(hk.transform(soft_q_net))

init_rng = jax.random.PRNGKey(0)
# One independent key per network. ValueNetwork and SoftQNetwork have the same
# architecture and input here, so initialising both from a single key would give
# them identical starting weights. The value network and its target deliberately
# share a key, so the target starts as an exact copy of the online value network.
value_rng, policy_rng, soft_q_rng = jax.random.split(init_rng, 3)

init_timestep = timestep = pendulum_environment.reset()
init_observation = init_timestep.observation[None]  # None -> add a batch dimension of size 1

target_value_online_params = target_value_net_init(value_rng, init_observation)
target_value_target_params = target_value_online_params

value_online_params = value_net_init(value_rng, init_observation)
value_target_params = value_online_params

policy_online_params = policy_net_init(policy_rng, init_observation)
policy_target_params = policy_online_params

soft_q_online_params = soft_q_net_init(soft_q_rng, init_observation)
soft_q_target_params = soft_q_online_params

target_value_net_apply_jitted = jax.jit(target_value_net_apply)
value_net_apply_jitted = jax.jit(value_net_apply)
policy_net_apply_jitted = jax.jit(policy_net_apply)
soft_q_net_apply_jitted = jax.jit(soft_q_net_apply)

# Smoke test: take one random step, then query the (untrained) policy.
batched_observation = tree.map_structure(lambda x: x[None], timestep.observation)
action = pendulum_random_agent.batched_actor_step(batched_observation)[0]  # batch size = 1
timestep = pendulum_environment.step(action)

print(policy_net_apply_jitted(policy_target_params, timestep.observation))

  lax_internal._check_user_dtype_supported(dtype, "zeros")


(DeviceArray([0.011], dtype=float32), DeviceArray([0.114], dtype=float32))


In [None]:
def loss(online_params, target_params, obs_tm1, action_tm1, reward_t, discount_t, obs_t, done):
  predicted_q_value = soft_q_net_apply_jitted(online_params, obs_tm1)
  predicted_q_value = jax.lax.stop_gradient(predicted_q_value)
  predicted_q_value = (1. - done[..., None]) * predicted_q_value

  predicted_value    = value_net_apply_jitted(online_params, obs_tm1)
  predicted_value = jax.lax.stop_gradient(predicted_value)
  predicted_q_value = (1. - done[..., None]) * predicted_q_value

  # Training Q function
  target_value    = target_value_net_apply_jitted(target_params, obs_tm1)
  target_value = jax.lax.stop_gradient(target_value)
  target_value = (1. - done[..., None]) * target_value

  target_q_value  = reward_t + discount_t * target_value