<a href="https://colab.research.google.com/github/abhishekunique/zurich-ss/blob/main/Brax_PG_Fast_Zurich_SS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#@title Import Brax and some helper modules
from IPython.display import clear_output

import collections
from datetime import datetime
import functools
import math
import time
from typing import Any, Callable, Dict, Optional, Sequence, List

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.envs import to_torch
from brax.io import metrics
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

# have torch allocate on device first, to prevent JAX from swallowing up all the
# GPU memory. By default JAX will pre-allocate 90% of the available GPU memory:
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
v = torch.ones(1, device='cuda')

Code for replay buffer

In [3]:
class ReplayBuffer(object):
    """Buffer to store environment transitions."""
    def __init__(self, obs_size, action_size, capacity, device):
        self.capacity = capacity
        self.device = device
        
        self.obses = np.empty((capacity, obs_size), dtype=np.float32)
        self.next_obses = np.empty((capacity, obs_size), dtype=np.float32)
        self.actions = np.empty((capacity, action_size), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)

        self.idx = 0
        self.last_save = 0
        self.full = False

    def __len__(self):
        return self.capacity if self.full else self.idx

    def add(self, obs, action, reward, next_obs, done):
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def sample(self, batch_size):
        idxs = np.random.randint(0,
                                 self.capacity if self.full else self.idx,
                                 size=batch_size)

        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs],
                                     device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)

        return obses, actions, rewards, next_obses, not_dones

Actor critic agent

In [5]:
class Agent(nn.Module):
  def __init__(self,
               policy_layers: Sequence[int],
               qf_layers: Sequence[int],
               target_qf_layers: Sequence[int],
               discount: float, 
               entropy_weight: float,
               device: str):
    super(Agent, self).__init__()

    policy = []
    for w1, w2 in zip(policy_layers, policy_layers[1:]):
      policy.append(nn.Linear(w1, w2))
      policy.append(nn.SiLU())
    policy.pop()  # drop the final activation
    self.policy = nn.Sequential(*policy)

    qf = []
    for w1, w2 in zip(qf_layers, qf_layers[1:]):
      qf.append(nn.Linear(w1, w2))
      qf.append(nn.SiLU())
    qf.pop()  # drop the final activation
    self.qf = nn.Sequential(*qf)
    
    target_qf = []
    for w1, w2 in zip(target_qf_layers, target_qf_layers[1:]):
      target_qf.append(nn.Linear(w1, w2))
      target_qf.append(nn.SiLU())
    target_qf.pop()  # drop the final activation
    self.target_qf = nn.Sequential(*target_qf)
    
    self.discount = discount 
    self.entropy_weight = entropy_weight
    self.device = device

  @torch.jit.export
  def dist_create(self, logits):
    """Normal followed by tanh.

    torch.distribution doesn't work with torch.jit, so we roll our own."""
    loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)
    scale = F.softplus(scale) + .001
    return loc, scale

  @torch.jit.export
  def dist_sample_no_postprocess(self, loc, scale):
    return torch.normal(loc, scale)

  @torch.jit.export
  def dist_entropy(self, loc, scale):
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    entropy = 0.5 + log_normalized
    entropy = entropy * torch.ones_like(loc)
    return entropy.sum(dim=-1)

  @torch.jit.export
  def dist_log_prob(self, loc, scale, dist):
    log_unnormalized = -0.5 * ((dist - loc) / scale).square()
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    log_prob = log_unnormalized - log_normalized
    return log_prob.sum(dim=-1)

  @torch.jit.export
  def get_logits_action(self, observation):
    logits = self.policy(observation)
    loc, scale = self.dist_create(logits)
    action = self.dist_sample_no_postprocess(loc, scale)
    entropy = self.dist_entropy(loc,  scale)
    log_prob = self.dist_log_prob(loc,  scale, action)
    return logits, action, entropy, log_prob

  # TODO: Check maximum entropy
  @torch.jit.export
  def compute_losses(self, obs_t, actions_t, rewards_t, next_obs_t, not_dones_t):
    # Policy loss
    logits = self.policy(obs_t)
    loc, scale = self.dist_create(logits)
    new_obs_actions = self.dist_sample_no_postprocess(loc, scale)

    # TODO: Change this over to REINFORCE loss maybe instead of reparameterization??
    q_new_actions = self.qf(torch.cat([obs_t, new_obs_actions], dim=-1))
    policy_loss = -q_new_actions.mean()

    # Compute Bellman loss
    q1_pred = self.qf(torch.cat([obs_t, actions_t], dim=-1))
    logits = self.policy(next_obs_t)
    loc, scale = self.dist_create(logits)
    new_next_actions = self.dist_sample_no_postprocess(loc, scale)

    target_q_values = self.target_qf(torch.cat([next_obs_t, new_next_actions], dim=-1))
    q_target = rewards_t + not_dones_t * self.discount * target_q_values

    # L2 error on bellman
    qf_loss = torch.linalg.norm(q1_pred - q_target.detach(), dim=-1).mean()

    return policy_loss, qf_loss

  @torch.jit.export
  def update_parameters(self, sample_trajs: List[torch.Tensor]):
      states = sample_trajs[0]
      actions = sample_trajs[1]
      rewards = sample_trajs[2]
      entropies = sample_trajs[3]
      log_probs = sample_trajs[4]
    
      # Bookkeeping
      R_EPS = 1e-9
      R = torch.zeros(rewards.shape[0],rewards.shape[1]).cuda()
      running_r = torch.zeros(rewards.shape[0],).cuda()
      baseline_losses = torch.zeros(rewards.shape[1],).cuda()
      
      # Compute discounted cumulative sum TODO: Check this
      for j in range(rewards.shape[1]):
          i = rewards.shape[1] - 1 - j
          running_r = self.discount * running_r + rewards[:, i]
          baseline_rpred = self.value(states[:, i])[:, 0]
          R[:, i] = running_r - baseline_rpred # Subtract the baseline
          baseline_loss = torch.sum((baseline_rpred - running_r)**2)
          baseline_losses[i] = baseline_loss
          
      # Normalize advantages
      R_mean = torch.mean(R)
      R_std = torch.std(R)
      R = (R - R_mean) / (R_std + R_EPS)
      
      # Compute loss
      loss = -(log_probs*R).sum() - self.entropy_weight*entropies.sum()
      loss = loss / len(rewards)
      baseline_loss = baseline_losses.sum() / len(rewards)
      loss += baseline_loss
      return loss

In [None]:
def flip_target(net, target_net, tau):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

In [None]:
def sample_trajectory(agent, env, num_steps, replay_buffer):
  """Return step data over multple unrolls."""
  observation = env.reset()
  states = []
  actions = []
  rewards = []
  entropies = []
  log_probs = []
  for _ in range(num_steps):
    logits, action, entropy, log_prob = agent.get_logits_action(observation)
    next_observation, reward, done, info = env.step(action)
    for j in range(next_observation.shape[0]):
        replay_buffer.add(observation.cpu().detach().numpy()[j], 
                          action.cpu().detach().numpy()[j], 
                          reward.cpu().detach().numpy()[j], 
                          next_observation.cpu().detach().numpy()[j], 
                          done.cpu().detach().numpy()[j])
    
    states.append(observation[None])
    actions.append(action[None])
    rewards.append(reward[None])
    entropies.append(entropy[None])
    log_probs.append(log_prob[None])
    
    observation = next_observation
  return [torch.transpose(torch.cat(states), 0, 1), 
          torch.transpose(torch.cat(actions), 0, 1),
          torch.transpose(torch.cat(rewards), 0, 1), 
          torch.transpose(torch.cat(entropies), 0, 1),
          torch.transpose(torch.cat(log_probs), 0, 1)]


In [None]:
def make_env(env_name, num_envs, episode_length, device):
  gym_name = f'brax-{env_name}-v0'
  if gym_name not in gym.envs.registry.env_specs:
    entry_point = functools.partial(envs.create_gym_env, env_name=env_name)
    gym.register(gym_name, entry_point=entry_point)
  env = gym.make(gym_name, batch_size=num_envs, episode_length=episode_length)
  env = to_torch.JaxToTorchWrapper(env, device=device)
  return env

In [8]:
def train(
    env_name: str = 'reacher',
    num_envs: int = 2048,
    episode_length: int = 100,
    device: str = 'cuda',
    num_epochs: int = 200,
    discount: float = 0.99,
    entropy_weight: float = 0.0001, 
    hidden_size: int = 128,
    learning_rate: float = 3e-4,
    capacity = 10000,
    batch_size = 32,
    target_flip_freq = 10,
    target_flip_tau = 5e-3,
    num_update_steps = 100
):
  # Define environment  
  env = make_env(env_name, num_envs, episode_length, device)
    
  # env warmup
  env.reset()
  action = torch.zeros(env.action_space.shape).to(device)
  env.step(action)

    
  # Define replay buffer
  replay_buffer = ReplayBuffer(env.observation_space.shape[1], 
                               env.action_space.shape[1], 
                               capacity, 
                               device)

  # create the agent
  policy_layers = [
      env.observation_space.shape[-1], 
      hidden_size, 
      hidden_size, 
      env.action_space.shape[-1] * 2
  ]
  qf_layers = [env.observation_space.shape[-1] + env.action_space.shape[-1], 
                  hidden_size, 
                  hidden_size, 
                  1]
  target_qf_layers = [env.observation_space.shape[-1] + env.action_space.shape[-1], 
                  hidden_size, 
                  hidden_size, 
                  1]
  agent = Agent(policy_layers, qf_layers, target_qf_layers, discount, entropy_weight, device)
  agent = torch.jit.script(agent.to(device))
  policy_optimizer = optim.Adam(agent.policy.parameters())
  qf_optimizer = optim.Adam(agent.qf.parameters())

  # Copy parameters initially
  flip_target(agent.qf, agent.target_qf, 1.0)
    
  returns = []

  for iter_num in range(num_epochs):
    # Sample trajectories
    sample_trajs = sample_trajectory(agent, env, episode_length, replay_buffer)
    rewards_np = sample_trajs[2].cpu().numpy().sum(axis=-1).mean()
    print("Episode: {}, reward: {}".format(iter_num, rewards_np))
    returns.append(rewards_np)
    
    # Perform actor-critic update
    for update_num in range(num_update_steps):
        obs_t, actions_t, rewards_t, next_obs_t, not_dones_t = replay_buffer.sample(batch_size)
        policy_loss, qf_loss = agent.compute_losses(obs_t, actions_t, rewards_t, next_obs_t, not_dones_t)
#         loss = policy_loss + qf_loss
        
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()
        
        qf_optimizer.zero_grad()
        qf_loss.backward()
        qf_optimizer.step()
    
    # Flip target network originally
    if iter_num % target_flip_freq == 0:
        flip_target(agent.qf, agent.target_qf, target_flip_tau)
        

  plt.plot(returns)

In [9]:
train()



XlaRuntimeError: INVALID_ARGUMENT: DLPack tensor is on GPU, but no GPU backend was provided.