<a href="https://colab.research.google.com/github/Aadharc/Policy-Gradient_1/blob/master/Policy_Gradient_Exercise_Zurich_SS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# IFI Summer School - Foundations of Reinforcement Learning (Policy Gradient)

Here we will cover the implementation of a basic reinforcement learning agent trained with a policy-gradient method.

The goal of this exercise is to implement the policy gradient theorem, taking the gradient of the discounted returns under the current policy. 



---


The key part of the implementation that you need to fill in is the loss computation: compute the returns, compute the log probabilities, and compute the surrogate objective and the baseline loss.

More details are found inline - fill out the sections labeled with [FILL THIS OUT]. 

First try this on the point environment, then try it on reacher (change env_name to 'point'/'reacher' in the Perform Training block)



---



In [None]:
#@title Get all dependencies

from IPython.display import clear_output

import collections
from datetime import datetime
import functools
import math
import time
from typing import Any, Callable, Dict, Optional, Sequence, List

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.envs import to_torch
from brax.io import metrics
from brax.training.agents.ppo import train as ppo
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

# additional transfer colab dependencies
import copy
from typing import Tuple, Optional, Union
from IPython.display import HTML, Image
from jax import numpy as jnp
from brax import jumpy as jp
from brax.io import html, image
from brax.envs import env as brax_env
from brax.envs import wrappers as brax_wrappers
from brax.envs import env


# Have torch allocate on the GPU first, to prevent JAX from swallowing up all
# the GPU memory. By default JAX will pre-allocate 90% of the available GPU
# memory: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
# The allocation is skipped when no CUDA device is present, so that running
# this cell on a CPU-only runtime does not raise.
if torch.cuda.is_available():
  v = torch.ones(1, device='cuda')

In [None]:
#@title [FILL THIS OUT] Define REINFORCE agent
class REINFORCEAgent(nn.Module):
  """REINFORCE agent with a Gaussian policy and a learned value baseline.

  `self.policy` maps an observation to the concatenated location and
  (pre-softplus) scale of a diagonal Normal over actions; `self.value` maps an
  observation to a scalar baseline. Both are MLPs with SiLU activations
  between the linear layers and no activation on the output.
  """

  def __init__(self,
               policy_layers: Sequence[int],
               value_layers: Sequence[int],
               discount: float,
               entropy_weight: float,
               device: str):
    """Builds the policy and baseline networks.

    Args:
      policy_layers: layer widths of the policy MLP, input width first. The
        last width must be twice the action size (location + scale).
      value_layers: layer widths of the baseline MLP, input width first.
      discount: discount factor used for the returns-to-go.
      entropy_weight: weight of the entropy bonus in the loss.
      device: device name; stored on the module, not otherwise used here.
    """
    super(REINFORCEAgent, self).__init__()

    # Policy definition
    policy = []
    for w1, w2 in zip(policy_layers, policy_layers[1:]):
      policy.append(nn.Linear(w1, w2))
      policy.append(nn.SiLU())
    policy.pop()  # drop the final activation
    self.policy = nn.Sequential(*policy)

    # Baseline definition
    value = []
    for w1, w2 in zip(value_layers, value_layers[1:]):
      value.append(nn.Linear(w1, w2))
      value.append(nn.SiLU())
    value.pop()  # drop the final activation
    self.value = nn.Sequential(*value)
    self.discount = discount
    self.entropy_weight = entropy_weight
    self.device = device

  # Distributional definitions (don't need to modify)
  @torch.jit.export
  def dist_create(self, logits):
    """Splits policy outputs into the location and scale of a Normal.

    torch.distribution doesn't work with torch.jit, so we roll our own. No
    tanh (or other squashing) is applied to the samples in this class.
    """
    loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)
    # softplus keeps the scale positive; the offset keeps it away from zero.
    scale = F.softplus(scale) + .001
    return loc, scale

  # Distributional definitions (don't need to modify)
  @torch.jit.export
  def dist_sample_no_postprocess(self, loc, scale):
    """Draws a sample from Normal(loc, scale) without any post-processing."""
    return torch.normal(loc, scale)

  # Distributional definitions (don't need to modify)
  @torch.jit.export
  def dist_entropy(self, loc, scale):
    """Entropy of Normal(loc, scale), summed over the last dimension."""
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    entropy = 0.5 + log_normalized
    entropy = entropy * torch.ones_like(loc)
    return entropy.sum(dim=-1)

  # Distributional definitions (don't need to modify)
  @torch.jit.export
  def dist_log_prob(self, loc, scale, dist):
    """Log-density of `dist` under Normal(loc, scale), summed over the last dim."""
    log_unnormalized = -0.5 * ((dist - loc) / scale).square()
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    log_prob = log_unnormalized - log_normalized
    return log_prob.sum(dim=-1)

  # Distributional definitions (don't need to modify)
  @torch.jit.export
  def get_logits_action(self, observation):
    """Runs the policy and samples an action.

    Returns:
      (logits, action, entropy, log_prob) for the given observation.
    """
    logits = self.policy(observation)
    loc, scale = self.dist_create(logits)
    action = self.dist_sample_no_postprocess(loc, scale)
    entropy = self.dist_entropy(loc,  scale)
    log_prob = self.dist_log_prob(loc,  scale, action)
    return logits, action, entropy, log_prob

  @torch.jit.export
  def compute_returns(self, rewards):
    """Discounted return-to-go R_t = sum_{t' >= t} discount^(t' - t) * r_t'.

    Args:
      rewards: tensor indexed as [trajectory, time step].

    Returns:
      Tensor of the same shape holding the return-to-go of every step.
    """
    returns = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[:, 0])
    # Go backwards through the trajectory so that every step can reuse the
    # return already accumulated for its successor.
    for t in range(rewards.shape[1] - 1, -1, -1):
      running = rewards[:, t] + self.discount * running
      returns[:, t] = running
    return returns

  @torch.jit.export
  def update_parameters(self, sample_trajs: List[torch.Tensor]):
    """Computes the policy-gradient loss for a batch of trajectories.

    The loss is the sum of
      * the surrogate policy loss -mean(log pi(a_t|s_t) * A_t), where A_t is
        the baseline-subtracted return, normalised to zero mean and unit std;
      * the baseline regression loss mean((b(s_t) - R_t)**2);
      * minus `entropy_weight` times the mean policy entropy.
    The gradient of this loss is taken in the outside training loop.

    Args:
      sample_trajs: [states, actions, rewards, entropies, log_probs], each
        indexed as [trajectory, time step, ...]. Only the first three entries
        are read: log-probabilities and entropies are re-evaluated from the
        stored states and actions, so the loss always carries a fresh autograd
        graph (calling this more than once on the same batch would otherwise
        backpropagate through an already-freed graph).

    Returns:
      Scalar loss tensor.
    """
    states = sample_trajs[0]
    actions = sample_trajs[1]
    rewards = sample_trajs[2]

    # Re-evaluate the policy on the visited states.
    logits = self.policy(states)
    loc, scale = self.dist_create(logits)
    log_probs = self.dist_log_prob(loc, scale, actions)
    entropies = self.dist_entropy(loc, scale)

    # Returns are regression targets / weights, never differentiated through.
    returns = self.compute_returns(rewards).detach()

    # Baseline b(s_t); detached in the advantage so that the policy term does
    # not push gradients into the value network.
    values = self.value(states).squeeze(-1)
    advantages = returns - values.detach()

    # Normalise; a small positive number in the denominator avoids dividing by
    # zero when all advantages are equal.
    advantages = (advantages - advantages.mean()) / (
        advantages.std(unbiased=False) + 1e-8)

    policy_loss = -(log_probs * advantages).mean()
    baseline_loss = (values - returns).square().mean()
    loss = (policy_loss + baseline_loss
            - self.entropy_weight * entropies.mean())
    return loss

In [None]:
#@title Sampling and batching environment data:
def sample_trajectory(agent, env, num_steps):
  """Rolls the agent out for `num_steps` steps and returns the step data.

  Args:
    agent: object whose `get_logits_action(observation)` returns
      (logits, action, entropy, log_prob).
    env: object with `reset()` returning an observation and `step(action)`
      returning (observation, reward, done, info).
    num_steps: number of environment steps to record.

  Returns:
    [states, actions, rewards, entropies, log_probs]; each is the per-step
    values stacked along a new leading axis and then swapped with axis 1, so
    the step index ends up on axis 1.
  """
  obs = env.reset()
  # One buffer per recorded quantity, kept in the order they are returned.
  state_buf, action_buf, reward_buf = [], [], []
  entropy_buf, log_prob_buf = [], []
  for _ in range(num_steps):
    _, action, entropy, log_prob = agent.get_logits_action(obs)
    next_obs, reward, _, _ = env.step(action)
    # The state stored for a step is the observation the action was based on.
    state_buf.append(obs)
    action_buf.append(action)
    reward_buf.append(reward)
    entropy_buf.append(entropy)
    log_prob_buf.append(log_prob)
    obs = next_obs
  buffers = (state_buf, action_buf, reward_buf, entropy_buf, log_prob_buf)
  return [torch.stack(buf).transpose(0, 1) for buf in buffers]

In [None]:
#@title Task wrapper

class TaskWrapper(brax_env.Wrapper):
  """Reacher wrapper that draws the target from a fixed set of positions.

  The candidate targets are computed once in `__init__`: either
  `num_positions` sampled targets, or the single target selected by `setting`.
  Every `reset` places one of those candidates into the scene. When
  `hide_target` is set, the observation handed to the agent omits the target
  information, while the reward is still computed from the full observation.
  """

  def __init__(self, env, hide_target=False,
               num_positions=2, setting=None, norandom_sampling=True):
    """Precomputes the candidate target positions.

    Args:
      env: environment to wrap.
      hide_target: if True, `_get_obs` leaves the target position and the
        tip-to-target vector out of the observation given to the agent.
      num_positions: number of candidate targets generated when `setting` is
        None.
      setting: None to generate `num_positions` targets, or 1, 2 or 3 to use
        one of the predefined targets.
      norandom_sampling: if True, every generated target is the same fixed
        point instead of a randomly drawn one.

    Raises:
      NotImplementedError: if `setting` is neither None nor one of 1, 2, 3.
    """
    super().__init__(env)
    self.num_positions = num_positions
    self.hide_target = hide_target
    self.setting = setting

    def target_positions_sampling(num_positions, rng):
      """Returns `num_positions` target locations slightly above the xy plane."""
      targets = []
      for _ in range(num_positions):
        rng, rng1, rng2 = jp.random_split(rng, 3)
        if norandom_sampling:
          # Fixed target: 0.2 / sqrt(2) on both axes, i.e. distance 0.2 from
          # the origin along the diagonal.
          target_x = 0.14142135623730953
          target_y = 0.14142135623730953
        else:
          dist = .2 * jp.random_uniform(rng1)
          ang = jp.pi * 2. * jp.random_uniform(rng2)
          target_x = dist * jp.cos(ang)
          target_y = dist * jp.sin(ang)
        target_z = .01
        target = jp.array([target_x, target_y, target_z]).transpose()
        targets.append(target)
      return jnp.array(targets), rng

    def deterministic_target_position(setting):
      """Returns the single predefined target location for `setting`."""
      # v1 scales the distance from the origin, v2 is the fraction of a full
      # turn used as the angle.
      if setting == 1:
        v1 = .8
        v2 = .5
      elif setting == 2:
        v1 = .4
        v2 = .0
      elif setting == 3:
        v1 = .8
        v2 = .2
      else:
        raise NotImplementedError('Only settings 1-3 are implemented.')

      dist = .2 * v1
      ang = jp.pi * 2. * v2
      target_x = dist * jp.cos(ang)
      target_y = dist * jp.sin(ang)
      target_z = .01
      target = jp.array([target_x, target_y, target_z]).transpose()

      return jnp.array([target])

    if setting is None:
      rng = jp.random_prngkey(seed=42)
      self.target_positions, _ = target_positions_sampling(
        self.num_positions, rng)
    else:
      self.target_positions = deterministic_target_position(setting)

  def new_target_position(self, rng: jp.ndarray) -> Tuple[jp.ndarray, jp.ndarray]:
    """Picks one of the precomputed target locations at random.

    Returns:
      (rng, target): the advanced random key and the chosen target position.
    """
    rng, rng1  = jp.random_split(rng, 2)

    # Draw over the number of targets actually stored: with a deterministic
    # `setting` there is exactly one, regardless of `num_positions`.
    index = jp.randint(rng1, low = 0, high = self.target_positions.shape[0])
    target = self.target_positions[index]

    return rng, target

  def reset(self, rng: jp.ndarray) -> brax_env.State:
    """Resets the arm to a slightly perturbed pose and places a target."""
    rng, rng1, rng2 = jp.random_split(rng, 3)

    self_sys = self.sys
    qpos = self_sys.default_angle() + jp.random_uniform(
        rng1, (self_sys.num_joint_dof,), -.1, .1)
    qvel = jp.random_uniform(rng2, (self_sys.num_joint_dof,), -.005, .005)

    qp = self_sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
    _, target = self.new_target_position(rng)
    # Move the target body to the chosen position.
    pos = jp.index_update(qp.pos, self._target_idx, target)

    qp = qp.replace(pos=pos)
    obs, _ = self._get_obs(qp, self_sys.info(qp))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'reward_dist': zero,
        'reward_ctrl': zero,
    }
    return brax_env.State(qp, obs, reward, done, metrics)

  def step(self, state, action):
    """Advances the physics one step and computes the reward.

    The reward is the negative tip-to-target distance plus the negative
    squared norm of the action.
    """
    qp, info = self.sys.step(state.qp, action)
    obs, obs_full = self._get_obs(qp, info)

    # vector from tip to target is last 3 entries of the full obs vector
    # (taken from the full observation so it is available with hide_target)
    reward_dist = -jp.norm(obs_full[-3:])
    reward_ctrl = -jp.square(action).sum()
    reward = reward_dist + reward_ctrl

    state.metrics.update(
        reward_dist=reward_dist,
        reward_ctrl=reward_ctrl,
    )

    return state.replace(qp=qp, obs=obs, reward=reward)

  def _get_obs(self, qp: brax.QP,
               info: brax.Info) -> Tuple[jp.ndarray, jp.ndarray]:
    """Egocentric observation of target and arm body.

    Returns:
      (obs, obs_full): the observation given to the agent and the full
      observation. They are identical unless `hide_target` is set, in which
      case `obs` contains only the joint-angle cos/sin and the tip velocity.
    """
    joint_angle, _ = self.sys.joints[0].angle_vel(qp)

    # qpos:
    # x,y coord of target
    qpos = [qp.pos[self._target_idx, :2]]

    # dist to target and speed of tip
    arm_qps = jp.take(qp, jp.array(self._arm_idx))
    tip_pos, tip_vel = arm_qps.to_world(jp.array([0.11, 0., 0.]))
    tip_to_target = [tip_pos - qp.pos[self._target_idx]]
    cos_sin_angle = [jp.cos(joint_angle), jp.sin(joint_angle)]

    # qvel:
    # velocity of tip
    qvel = [tip_vel[:2]]

    if self.hide_target:
      # return used and full observation
      return (jp.concatenate(cos_sin_angle + qvel),
              jp.concatenate(cos_sin_angle + qpos + qvel + tip_to_target))
    else:
      return (jp.concatenate(cos_sin_angle + qpos + qvel + tip_to_target),
              jp.concatenate(cos_sin_angle + qpos + qvel + tip_to_target))

  @property
  def observation_size(self) -> int:
    """The size of the observation vector returned in step and reset."""
    # Measured by performing a reset, so it reflects `hide_target`.
    rng = jp.random_prngkey(0)
    reset_state = self.reset(rng)
    return reset_state.obs.shape[-1]

#@title Additional environment definitions

def create_reacher_env(
           episode_length: int = 1000,
           action_repeat: int = 1,
           auto_reset: bool = True,
           batch_size: Optional[int] = None,
           eval_metrics: bool = False,
           task_wrapper: bool = False,
           num_positions = 1,
           hide_target = False,
           setting = None,
           **kwargs):
  """Builds a Brax reacher environment and applies the requested wrappers.

  Wrappers are applied in a fixed order, each only when its switch is set:
  task wrapper, episode wrapper, vector wrapper, auto-reset wrapper and
  evaluation wrapper. Remaining keyword arguments are forwarded to the
  `Reacher` constructor.
  """
  wrapped = brax.envs.reacher.Reacher(**kwargs)

  # (switch, wrapper factory) pairs, listed in application order.
  wrapper_steps = (
      (task_wrapper,
       lambda e: TaskWrapper(e, num_positions=num_positions,
                             hide_target=hide_target, setting=setting)),
      (episode_length is not None,
       lambda e: brax_wrappers.EpisodeWrapper(e, episode_length,
                                              action_repeat)),
      (batch_size,
       lambda e: brax_wrappers.VectorWrapper(e, batch_size)),
      (auto_reset,
       lambda e: brax_wrappers.AutoResetWrapper(e)),
      (eval_metrics,
       lambda e: brax_wrappers.EvalWrapper(e)),
  )
  for enabled, wrap in wrapper_steps:
    if enabled:
      wrapped = wrap(wrapped)

  return wrapped  # type: ignore

def create_gym_reacher_env(
                   batch_size: Optional[int] = None,
                   seed: int = 0,
                   backend: Optional[str] = None,
                   **kwargs) -> Union[gym.Env, gym.vector.VectorEnv]:
  """Creates a `gym.Env` or `gym.vector.VectorEnv` from the Brax reacher env.

  Args:
    batch_size: None for a single environment, or a positive number of
      environments to run in a batch.
    seed: seed passed to the gym wrapper.
    backend: backend passed to the gym wrapper.
    **kwargs: forwarded to `create_reacher_env`.

  Raises:
    ValueError: if `batch_size` is neither None nor a positive integer.
  """
  # Validate first, so an invalid batch size is rejected before it is used to
  # build (and possibly vectorise) the underlying environment.
  if batch_size is not None and batch_size <= 0:
    raise ValueError(
        '`batch_size` should either be None or a positive integer.')
  environment = create_reacher_env(batch_size=batch_size, **kwargs)
  if batch_size is None:
    return brax_wrappers.GymWrapper(environment, seed=seed, backend=backend)
  return brax_wrappers.VectorGymWrapper(environment, seed=seed, backend=backend)


class PointMass(env.Env):
  """Two-dimensional point that is rewarded for getting close to a goal.

  The point starts at the origin, the goal is the point (1, 1), and each
  action is integrated as a velocity over one time step `dt`.
  """

  def __init__(self, **kwargs):
    super().__init__(config='dt: .02', **kwargs)

  def reset(self, rng: jnp.ndarray) -> brax_env.State:
    """Returns the initial state: point at the origin, goal stored in `info`."""
    reward, done = jnp.zeros(2)
    start_qp = brax.QP(pos=jnp.zeros(2), vel=jnp.zeros(2),
                       rot=jnp.zeros(2), ang=jnp.zeros(2))
    start_obs = jnp.zeros(2)
    return brax_env.State(start_qp, start_obs, reward, done,
                          info={'goal': jnp.ones(2)})

  def step(self, state: env.State, action: jnp.ndarray) -> brax_env.State:
    """Moves the point by `action * dt`; the observation is the new position."""
    new_pos = state.qp.pos + action * self.sys.config.dt
    goal = state.info["goal"]
    # Reward is the negative squared distance between goal and new position.
    dx = goal[0] - new_pos[0]
    dy = goal[1] - new_pos[1]
    return state.replace(qp=state.qp.replace(pos=new_pos),
                         obs=new_pos.copy(),
                         reward=-(dx**2 + dy**2))

  @property
  def observation_size(self):
    """Length of the observation vector."""
    return 2

  @property
  def action_size(self):
    """Length of the action vector."""
    return 2

#@title Additional environment definitions
def create_point_env(
           episode_length: int = 1000,
           action_repeat: int = 1,
           auto_reset: bool = True,
           batch_size: Optional[int] = None,
           eval_metrics: bool = False,
           **kwargs):
  """Builds a `PointMass` environment and applies the requested wrappers.

  Wrappers are applied in a fixed order, each only when its switch is set:
  episode wrapper, vector wrapper, auto-reset wrapper and evaluation wrapper.
  Remaining keyword arguments are forwarded to the `PointMass` constructor.
  """
  wrapped = PointMass(**kwargs)

  # (switch, wrapper factory) pairs, listed in application order.
  wrapper_steps = (
      (episode_length is not None,
       lambda e: brax_wrappers.EpisodeWrapper(e, episode_length,
                                              action_repeat)),
      (batch_size,
       lambda e: brax_wrappers.VectorWrapper(e, batch_size)),
      (auto_reset,
       lambda e: brax_wrappers.AutoResetWrapper(e)),
      (eval_metrics,
       lambda e: brax_wrappers.EvalWrapper(e)),
  )
  for enabled, wrap in wrapper_steps:
    if enabled:
      wrapped = wrap(wrapped)

  return wrapped  # type: ignore

def create_gym_point_env(
                   batch_size: Optional[int] = None,
                   seed: int = 0,
                   backend: Optional[str] = None,
                   **kwargs) -> Union[gym.Env, gym.vector.VectorEnv]:
  """Creates a `gym.Env` or `gym.vector.VectorEnv` from the point-mass env.

  Args:
    batch_size: None for a single environment, or a positive number of
      environments to run in a batch.
    seed: seed passed to the gym wrapper.
    backend: backend passed to the gym wrapper.
    **kwargs: forwarded to `create_point_env`.

  Raises:
    ValueError: if `batch_size` is neither None nor a positive integer.
  """
  # Validate first, so an invalid batch size is rejected before it is used to
  # build (and possibly vectorise) the underlying environment.
  if batch_size is not None and batch_size <= 0:
    raise ValueError(
        '`batch_size` should either be None or a positive integer.')
  environment = create_point_env(batch_size=batch_size, **kwargs)
  if batch_size is None:
    return brax_wrappers.GymWrapper(environment, seed=seed, backend=backend)
  return brax_wrappers.VectorGymWrapper(environment, seed=seed, backend=backend)

def make_env(env_name: str = 'reacher',
            num_envs: int = 2048,
            episode_length: int = 100,
            device = 'cuda',
            num_target_positions = 1,
            hide_target = False,
            setting = None):
  """Creates a batched gym environment that exchanges torch tensors.

  Args:
    env_name: 'reacher' or 'point'.
    num_envs: number of environments run in a batch.
    episode_length: number of steps per episode.
    device: torch device the converted tensors are placed on.
    num_target_positions: number of candidate targets (reacher only).
    hide_target: whether to hide the target from the agent (reacher only).
    setting: predefined target setting, or None (reacher only).

  Returns:
    The gym environment wrapped in `to_torch.JaxToTorchWrapper`.

  Raises:
    ValueError: if `env_name` is not one of the supported names.
  """
  if env_name == 'reacher':
    env = create_gym_reacher_env(task_wrapper=True,
                                num_positions=num_target_positions,
                                hide_target=hide_target,
                                setting=setting,
                                batch_size=num_envs,
                                episode_length=episode_length)
  elif env_name == 'point':
    env = create_gym_point_env(batch_size=num_envs,
                               episode_length=episode_length)
  else:
    # Without this branch an unknown name would only surface below as an
    # unbound local variable.
    raise ValueError(
        f"Unknown env_name {env_name!r}; expected 'reacher' or 'point'.")

  # automatically convert between jax ndarrays and torch tensors:
  env = to_torch.JaxToTorchWrapper(env, device=device)

  return env

In [None]:
#@title Define training loop

def train(
    env_name: str = 'reacher',
    num_envs: int = 2048,
    episode_length: int = 100,
    device: str = 'cuda',
    num_epochs: int = 200,
    discount: float = 0.99,
    entropy_weight: float = 0.0001,
    hidden_size: int = 128,
    learning_rate: float = 3e-4,
    num_update_steps = 1
):
  """Trains a REINFORCE agent and plots the mean return per epoch.

  Args:
    env_name: environment name passed to `make_env`.
    num_envs: number of environments run in a batch.
    episode_length: number of steps sampled per epoch.
    device: torch device for the agent and the environment tensors.
    num_epochs: number of sample/update iterations.
    discount: discount factor of the agent.
    entropy_weight: entropy weight of the agent.
    hidden_size: width of the two hidden layers of both networks.
    learning_rate: learning rate of the Adam optimizer.
    num_update_steps: gradient steps taken on each sampled batch.

  Returns:
    (env, agent): the environment and the trained (scripted) agent.
  """

  # Define environment
  env = make_env(env_name, num_envs, episode_length, device)

  # env warmup
  env.reset()
  action = torch.zeros(env.action_space.shape).to(device)
  env.step(action)

  # create the agent
  obs_size = env.observation_space.shape[-1]
  action_size = env.action_space.shape[-1]
  # The policy outputs a location and a scale per action dimension.
  policy_layers = [obs_size, hidden_size, hidden_size, action_size * 2]
  value_layers = [obs_size, hidden_size, hidden_size, 1]
  agent = REINFORCEAgent(policy_layers, value_layers, discount, entropy_weight, device)
  agent = torch.jit.script(agent.to(device))
  # Pass the requested learning rate; otherwise Adam silently falls back to
  # its own default and the `learning_rate` argument has no effect.
  optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

  # Bookkeeping
  returns = []

  # Training loop
  for iter_num in range(num_epochs):
    # Sample trajectories
    sample_trajs = sample_trajectory(agent, env, episode_length)
    # Undiscounted reward summed over the last axis, averaged over the rest.
    rewards_np = sample_trajs[2].cpu().numpy().sum(axis=-1).mean()
    print("Episode: {}, reward: {}".format(iter_num, rewards_np))
    returns.append(rewards_np)

    # Perform policy-gradient update
    for update_num in range(num_update_steps):
        loss = agent.update_parameters(sample_trajs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

  plt.plot(returns)
  return env, agent

In [None]:
#@title Perform training
# TODO: Adjust parameters if necessary
# Train on the point environment first; the name must be one that `make_env`
# recognises ('point' or 'reacher').
env_name = 'point'
env, agent = train(env_name=env_name, 
                   num_envs=1000, 
                   episode_length=200, 
                   num_epochs=200)

In [None]:
#@title Visualize agent if point

# Roll the trained agent out for 200 steps, recording the observation seen at
# every step (and the rewards, kept for inspection).
all_obs = []
all_rewards = []
observation = env.reset()
for _ in range(200):
  all_obs.append(observation[:, None, :])
  logits, action, entropy, log_prob = agent.get_logits_action(observation)
  observation, reward, done, info = env.step(action) 
  all_rewards.append(reward[:, None])
# Concatenate along the inserted axis so that axis 1 indexes the step.
all_obs = torch.cat(all_obs, dim=1)
all_obs = all_obs.cpu().detach().numpy()

plt.clf()
plt.cla()
# One curve per environment; iterate over the actual batch size so the cell
# keeps working when `num_envs` is changed in the training cell.
for j in range(all_obs.shape[0]):
  plt.plot(all_obs[j, :, 0],  all_obs[j, :, 1])
# Mark the point (1, 1).
plt.scatter(1, 1, marker='x', s=20)
plt.show()

In [None]:
#@title Perform training
# TODO: Adjust parameters if necessary
# Second run: same training settings, this time on the reacher environment.
env_name = 'reacher'
env, agent = train(
    env_name=env_name,
    num_envs=1000,
    episode_length=200,
    num_epochs=200,
)

In [None]:
#@title Visualize agent if reacher

def visualise_agent(env, agent, episode_length, num_episodes=3):
  """Rolls the agent out and collects per-step physics states for rendering.

  Args:
    env: torch-wrapped batched environment; the batched Brax state is read
      from the private `env.env._state` attribute after every step.
    agent: object whose `get_logits_action(observation)` returns
      (logits, action, entropy, log_prob).
    episode_length: number of steps to roll out.
    num_episodes: number of batch entries (indices 0..num_episodes-1) whose
      states are recorded.

  Returns:
    A flat list of `brax.QP` states: all steps of batch entry 0, followed by
    all steps of batch entry 1, and so on.
  """
  episodes = [[] for _ in range(num_episodes)]

  observation = env.reset()
  for _ in range(episode_length):
    _, action, _, _ = agent.get_logits_action(observation)
    observation, _, _, _ = env.step(action)
    batch_state = env.env._state
    # Slice the i-th environment out of the batched state.
    for i, episode in enumerate(episodes):
      episode.append(brax.QP(batch_state.qp.pos[i], batch_state.qp.rot[i],
                             batch_state.qp.vel[i], batch_state.qp.ang[i]))

  # Concatenate the recorded episodes back to back.
  rollout = []
  for episode in episodes:
    rollout.extend(episode)

  return rollout

# Record 200 steps for the first 10 environments of the batch.
rollouts = visualise_agent(env, agent, episode_length=200, num_episodes=10)
# Render the recorded states; this must stay the last expression of the cell
# so that the notebook displays the resulting HTML.
HTML(html.render(env.env._env.sys, rollouts))