In [None]:
!pip install --upgrade gym --quiet

In [None]:
#@title Import Brax and some helper modules
from IPython.display import clear_output

import collections
from datetime import datetime
import functools
import math
import os
import time
from typing import Any, Callable, Dict, Optional, Sequence

import brax

from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.envs.wrappers import torch as torch_wrapper
from brax.io import metrics
from brax.training.agents.ppo import train as ppo
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F

In [None]:
class Agent(nn.Module):
  """Tanh-Gaussian policy with running observation normalization.

  The policy is an MLP whose output is split in half into the location and
  the (pre-softplus) scale of a Normal distribution; sampled actions are
  squashed with `tanh` by `dist_postprocess`. `loss` is an imitation
  objective built from chamfer distances to demonstration data plus an
  entropy term.

  `discounting`, `reward_scaling`, `lambda_` and `epsilon` are stored on the
  instance but are not read by any method of this class.

  Args:
    policy_layers: layer widths of the MLP; the first entry is the observation
      size and the last entry is twice the action size (loc and scale).
    entropy_cost: weight of the entropy term in `loss`.
    discounting: stored as `self.discounting`.
    reward_scaling: stored as `self.reward_scaling`.
    device: device on which the normalization statistics are allocated.
  """

  def __init__(self,
               policy_layers: Sequence[int],
               entropy_cost: float,
               discounting: float,
               reward_scaling: float,
               device: str):
    super(Agent, self).__init__()

    policy = []
    for w1, w2 in zip(policy_layers, policy_layers[1:]):
      policy.append(nn.Linear(w1, w2))
      policy.append(nn.SiLU())
    policy.pop()  # drop the final activation
    self.policy = nn.Sequential(*policy)

    # Running statistics used by `normalize`; plain tensors, not parameters.
    self.num_steps = torch.zeros((), device=device)
    self.running_mean = torch.zeros(policy_layers[0], device=device)
    self.running_variance = torch.zeros(policy_layers[0], device=device)

    self.entropy_cost = entropy_cost
    self.discounting = discounting
    self.reward_scaling = reward_scaling
    self.lambda_ = 0.95
    self.epsilon = 0.3
    self.device = device

  @torch.jit.export
  def dist_create(self, logits):
    """Normal followed by tanh.

    torch.distribution doesn't work with torch.jit, so we roll our own.

    Splits `logits` in half along the last axis and returns `(loc, scale)`,
    where `scale = softplus(raw_scale) + 0.001` is strictly positive.
    """
    loc, scale = torch.split(logits, logits.shape[-1] // 2, dim=-1)
    scale = F.softplus(scale) + .001
    return loc, scale

  @torch.jit.export
  def dist_sample_no_postprocess(self, loc, scale):
    """Draw a Normal(loc, scale) sample; the tanh squashing is NOT applied."""
    return torch.normal(loc, scale)

  @classmethod
  def dist_postprocess(cls, x):
    """Squash a raw sample into (-1, 1) with tanh."""
    return torch.tanh(x)

  @torch.jit.export
  def dist_entropy(self, loc, scale):
    """Single-sample estimate of the tanh-Normal entropy, summed over the last axis.

    The Normal entropy is computed in closed form; the tanh correction
    (log|det J|) is evaluated at one freshly drawn sample, so the returned
    value is stochastic.
    """
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    entropy = 0.5 + log_normalized
    entropy = entropy * torch.ones_like(loc)
    dist = torch.normal(loc, scale)
    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
    entropy = entropy + log_det_jacobian
    return entropy.sum(dim=-1)

  @torch.jit.export
  def dist_log_prob(self, loc, scale, dist):
    """Log-probability of the tanh-squashed action, summed over the last axis.

    `dist` is the raw (pre-tanh) sample; the Normal log-density is corrected
    by the log|det J| of the tanh transform.
    """
    log_unnormalized = -0.5 * ((dist - loc) / scale).square()
    log_normalized = 0.5 * math.log(2 * math.pi) + torch.log(scale)
    log_det_jacobian = 2 * (math.log(2) - dist - F.softplus(-2 * dist))
    log_prob = log_unnormalized - log_normalized - log_det_jacobian
    return log_prob.sum(dim=-1)

  @torch.jit.export
  def update_normalization(self, observation):
    """Fold a batch of RAW observations into the running mean/variance.

    The two leading axes of `observation` are treated as sample axes: the step
    count grows by `shape[0] * shape[1]` and sums are taken over `dim=(0, 1)`.
    `running_variance` accumulates a sum of squared deviations; `normalize`
    divides it by the step count.
    """
    self.num_steps += observation.shape[0] * observation.shape[1]
    input_to_old_mean = observation - self.running_mean
    mean_diff = torch.sum(input_to_old_mean / self.num_steps, dim=(0, 1))
    self.running_mean = self.running_mean + mean_diff
    input_to_new_mean = observation - self.running_mean
    var_diff = torch.sum(input_to_new_mean * input_to_old_mean, dim=(0, 1))
    self.running_variance = self.running_variance + var_diff

  @torch.jit.export
  def normalize(self, observation):
    """Standardize with the running statistics and clip the result to [-5, 5].

    The variance estimate is clipped to [1e-6, 1e6] before the square root so
    the division is well defined even before any statistics were collected.
    """
    variance = self.running_variance / (self.num_steps + 1.0)
    variance = torch.clip(variance, 1e-6, 1e6)
    return ((observation - self.running_mean) / variance.sqrt()).clip(-5, 5)

  @torch.jit.export
  def get_logits_action(self, observation):
    """Return `(logits, action)` for a raw observation.

    `action` is the un-squashed sample; callers apply `dist_postprocess`
    before stepping the environment.
    """
    observation = self.normalize(observation)
    logits = self.policy(observation)
    loc, scale = self.dist_create(logits)
    action = self.dist_sample_no_postprocess(loc, scale)
    return logits, action

  @torch.jit.export
  def chamfer_loss(self, pred, target, eps: float = 1e-12):
    """One-directional chamfer distance from `pred` to `target`.

    For every point in `pred`, takes the RMS (over the last axis) distance to
    its nearest point in `target`, then averages over all `pred` points.

    Args:
      pred: 3-D tensor indexed as (batch, pred_points, features).
      target: 3-D tensor indexed as (batch, target_points, features); the
        batch and feature sizes must match `pred`.
      eps: floor applied to the mean squared distance before the square root.
        `sqrt` has an infinite derivative at 0, so without the floor a `pred`
        point that coincides exactly with a `target` point produces NaN
        gradients.

    Returns:
      Scalar tensor.
    """
    # Broadcasting yields the (batch, pred_points, target_points, features)
    # difference directly, without materialising repeated copies.
    diff = pred.unsqueeze(2) - target.unsqueeze(1)
    mean_sq = (diff ** 2).mean(dim=-1)
    distances = torch.sqrt(torch.clamp(mean_sq, min=eps))
    min_distances, _ = distances.min(dim=-1)
    return min_distances.mean()

  @torch.jit.export
  def loss(self, td: Dict[str, torch.Tensor], demo_traj, demo_traj_action):
    """Imitation loss: state chamfer + action chamfer + entropy term, tanh-squashed.

    Args:
      td: rollout data. Reads `td['observation']`, `td['action_list']` and
        `td['qp_list']['pos' | 'rot' | 'vel' | 'ang']`.
      demo_traj: demonstration states, compared against the rollout states.
      demo_traj_action: demonstration actions, compared against
        `td['action_list']`.

    Returns:
      Scalar tensor in (-1, 1).

    This method does not touch the running normalization statistics: they must
    be updated from raw observations via `update_normalization`. Updating them
    here from the already-normalized minibatch would corrupt the statistics.
    """
    # NOTE(review): `td` is annotated Dict[str, Tensor], yet `td['qp_list']`
    # is indexed by string below, so it has to be a nested mapping. TorchScript
    # will most likely reject this method as annotated - confirm and adjust.
    observation = self.normalize(td['observation'])
    policy_logits = self.policy(observation)

    # Flatten each state component to (dim0, dim1, -1) and concatenate them
    # along the feature axis.
    qp = td['qp_list']
    rollout_traj = torch.cat([
        qp['pos'].reshape((qp['pos'].shape[0], qp['pos'].shape[1], -1)),
        qp['rot'].reshape((qp['rot'].shape[0], qp['rot'].shape[1], -1)),
        qp['vel'].reshape((qp['vel'].shape[0], qp['vel'].shape[1], -1)),
        qp['ang'].reshape((qp['ang'].shape[0], qp['ang'].shape[1], -1)),
    ], dim=-1)

    # NOTE(review): the demo states are standardized with the OBSERVATION
    # statistics while `rollout_traj` is left as-is - confirm this asymmetry
    # is intended.
    demo_traj_ = self.normalize(demo_traj)

    # Calculate state chamfer loss
    cf_loss = self.chamfer_loss(rollout_traj, demo_traj_)

    # Calculate action chamfer loss
    cf_action_loss = self.chamfer_loss(td['action_list'], demo_traj_action)

    # Calculate entropy loss: negative Gaussian log-variance term, so that
    # minimising the loss pushes the policy scale up.
    _, raw_scale = torch.chunk(policy_logits, 2, dim=-1)
    sigma_list = F.softplus(raw_scale) + 1e-6  # small epsilon for numerical stability
    entropy_loss = -0.5 * torch.log(2 * math.pi * sigma_list ** 2).mean()

    # Combine losses and squash into (-1, 1).
    final_loss = cf_loss + cf_action_loss + self.entropy_cost * entropy_loss
    final_loss = torch.tanh(final_loss)

    return final_loss

In [None]:
# Record of rollout data; each field holds either a list of per-step tensors
# (while collecting) or a single stacked tensor (after `sd_map(torch.stack, ...)`).
StepData = collections.namedtuple(
    'StepData',
    [
        'observation',
        'logits',
        'action',
        'reward',
        'done',
        'truncation',
    ],
)


def sd_map(f: Callable[..., torch.Tensor], *sds) -> StepData:
  """Apply `f` field by field across one or more StepData tuples.

  For every field name of the first tuple, `f` receives that field's value
  from each tuple in `sds` (in order) and its result becomes the field of the
  returned StepData.
  """
  as_dicts = [sd._asdict() for sd in sds]
  return StepData(**{
      field: f(*(d[field] for d in as_dicts))
      for field in as_dicts[0].keys()
  })


def eval_unroll(agent, env, length):
  """Run the policy for `length` steps from a fresh reset and summarize it.

  Returns:
    A pair `(episodes, mean_reward)`: the summed `done` flags over the whole
    unroll, and the summed reward divided by that count. If no `done` flag is
    raised the division is by zero.
  """
  obs = env.reset()
  finished = torch.zeros((), device=agent.device)
  reward_total = torch.zeros((), device=agent.device)
  for _step in range(length):
    raw_action = agent.get_logits_action(obs)[1]
    obs, reward, done, _ = env.step(Agent.dist_postprocess(raw_action))
    finished += done.sum()
    reward_total += reward.sum()
  mean_reward = reward_total / finished
  return finished, mean_reward


def train_unroll(agent, env, observation, num_unrolls, unroll_length):
  """Collect step data over multiple consecutive unrolls.

  Each unroll starts from the observation the previous one ended on and keeps
  `unroll_length + 1` observations (the starting one plus one per step) but
  `unroll_length` entries for every other field.

  Returns:
    `(observation, td)`: the last observation seen, and a StepData whose
    fields are stacked first over steps within an unroll and then over unrolls.
  """
  collected = StepData([], [], [], [], [], [])
  for _ in range(num_unrolls):
    obs_seq = [observation]
    logits_seq, action_seq, reward_seq = [], [], []
    done_seq, truncation_seq = [], []
    for _ in range(unroll_length):
      logits, action = agent.get_logits_action(observation)
      observation, reward, done, info = env.step(Agent.dist_postprocess(action))
      obs_seq.append(observation)
      logits_seq.append(logits)
      action_seq.append(action)
      reward_seq.append(reward)
      done_seq.append(done)
      truncation_seq.append(info['truncation'])
    # Stack each field over the step axis, in StepData field order.
    stacked = StepData(*map(torch.stack, (
        obs_seq, logits_seq, action_seq, reward_seq, done_seq, truncation_seq)))
    for bucket, tensor in zip(collected, stacked):
      bucket.append(tensor)
  td = sd_map(torch.stack, collected)
  return observation, td


def train(
    env_name: str = 'ant',
    num_envs: int = 2048,
    episode_length: int = 1000,
    device: str = 'cuda',
    num_timesteps: int = 30_000_000,
    eval_frequency: int = 10,
    unroll_length: int = 5,
    batch_size: int = 1024,
    num_minibatches: int = 32,
    num_update_epochs: int = 4,
    reward_scaling: float = .1,
    entropy_cost: float = 1e-2,
    discounting: float = .97,
    learning_rate: float = 3e-4,
    progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None,
):
  """Train an `Agent` on a Brax environment against expert demonstrations.

  Alternates `eval_frequency` training phases with evaluations. When
  `progress_fn` is given it is called as `progress_fn(total_steps, metrics)`
  before every training phase and once more after the last one.

  Expert data is read from `expert/{env_name}_traj_state.npy` and
  `expert/{env_name}_traj_action.npy`, truncated to `episode_length` along the
  first axis and tiled across `num_envs` on a new second axis.

  Args:
    env_name: Brax environment name; also selects the demonstration files.
    num_envs: number of parallel environments.
    episode_length: environment episode length, evaluation unroll length and
      the number of demonstration steps kept.
    device: torch device for the agent and all tensors.
    num_timesteps: total environment-step budget across all training phases.
    eval_frequency: number of training phases.
    unroll_length: steps per collected unroll.
    batch_size: with `num_minibatches`, determines how many unrolls are
      collected per epoch.
    num_minibatches: minibatches per update epoch.
    num_update_epochs: passes over each collected batch.
    reward_scaling: forwarded to `Agent`.
    entropy_cost: forwarded to `Agent`.
    discounting: forwarded to `Agent`.
    learning_rate: Adam learning rate.
    progress_fn: optional metrics callback.

  Returns:
    None.
  """

  env = envs.create(env_name, batch_size=num_envs,
                    episode_length=episode_length,
                    backend='spring')
  env = gym_wrapper.VectorGymWrapper(env)
  # automatically convert between jax ndarrays and torch tensors:
  env = torch_wrapper.TorchWrapper(env, device=device)

  # env warmup
  env.reset()
  action = torch.zeros(env.action_space.shape).to(device)
  env.step(action)

  # create the agent
  policy_layers = [
      env.observation_space.shape[-1], 64, 64, env.action_space.shape[-1] * 2
  ]

  agent = Agent(policy_layers, entropy_cost, discounting,
                reward_scaling, device)
  # NOTE(review): `Agent.loss` indexes `td['qp_list']` with string keys while
  # `td` is annotated Dict[str, Tensor]; scripting will most likely fail on
  # that method - confirm.
  agent = torch.jit.script(agent.to(device))
  optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

  sps = 0
  total_steps = 0
  total_loss = 0

  # Expert demonstrations. The loss operates on torch tensors, so the arrays
  # are converted to torch on `device` rather than to JAX arrays.
  def load_demo(kind):
    demo = np.load(f"expert/{env_name}_traj_{kind}.npy")[:episode_length]
    demo = torch.as_tensor(demo, dtype=torch.float32, device=device)
    # (steps, ...) -> (steps, num_envs, ...)
    return demo.unsqueeze(1).repeat_interleave(num_envs, dim=1)

  demo_traj = load_demo('state')
  demo_traj_action = load_demo('action')

  # Sizes are fixed for the whole run, so compute them once.
  num_steps = batch_size * num_minibatches * unroll_length
  num_epochs = num_timesteps // (num_steps * eval_frequency)
  num_unrolls = batch_size * num_minibatches // env.num_envs
  num_updates = num_epochs * num_update_epochs * num_minibatches

  def unroll_first(data):
    # Put the step axis first and merge the (unroll, env) axes into one.
    data = data.swapaxes(0, 1)
    return data.reshape([data.shape[0], -1] + list(data.shape[3:]))

  for eval_i in range(eval_frequency + 1):
    if progress_fn:
      t = time.time()
      with torch.no_grad():
        episode_count, episode_reward = eval_unroll(agent, env, episode_length)
      duration = time.time() - t
      # TODO: only count stats from completed episodes
      episode_avg_length = env.num_envs * episode_length / episode_count
      eval_sps = env.num_envs * episode_length / duration
      progress = {
          'eval/episode_reward': episode_reward,
          'eval/completed_episodes': episode_count,
          'eval/avg_episode_length': episode_avg_length,
          'speed/sps': sps,
          'speed/eval_sps': eval_sps,
          'losses/total_loss': total_loss,
      }
      progress_fn(total_steps, progress)

    if eval_i == eval_frequency:
      break

    observation = env.reset()
    total_loss = 0
    t = time.time()
    for _ in range(num_epochs):
      observation, td = train_unroll(agent, env, observation, num_unrolls,
                                     unroll_length)

      # make unroll first
      td = sd_map(unroll_first, td)

      # update normalization statistics
      agent.update_normalization(td.observation)

      for _ in range(num_update_epochs):
        # shuffle and batch the data
        with torch.no_grad():
          permutation = torch.randperm(td.observation.shape[1], device=device)
          def shuffle_batch(data):
            data = data[:, permutation]
            data = data.reshape([data.shape[0], num_minibatches, -1] +
                                list(data.shape[2:]))
            return data.swapaxes(0, 1)
          epoch_td = sd_map(shuffle_batch, td)

        for minibatch_i in range(num_minibatches):
          td_minibatch = sd_map(lambda d: d[minibatch_i], epoch_td)
          # NOTE(review): `Agent.loss` reads td['qp_list'] and
          # td['action_list'], which are not StepData fields; the rollout
          # collection has to record them before this call can succeed.
          loss = agent.loss(td_minibatch._asdict(), demo_traj, demo_traj_action)
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          # Detach so the running sum does not keep autograd graphs alive.
          total_loss += loss.detach()

    duration = time.time() - t
    total_steps += num_epochs * num_steps
    if num_updates:  # avoid 0/0 when the step budget is below one epoch
      total_loss = total_loss / num_updates
    sps = num_epochs * num_steps / duration

In [None]:
# Training throughput samples, appended to by `progress`.
train_sps = []


def progress(_, metrics):
  """Progress callback: record 'training/sps' whenever the metrics carry it."""
  key = 'training/sps'
  if key not in metrics:
    return
  train_sps.append(metrics[key])

# Run Brax's own PPO trainer on 'ant' with the same hyperparameters as the
# defaults of `train` above, collecting throughput through `progress`.
ppo.train(
    environment=envs.create(env_name='ant', backend='spring'),
    num_timesteps=30_000_000,
    num_evals=10,
    reward_scaling=.1,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_updates_per_batch=4,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=2048,
    batch_size=1024,
    progress_fn=progress,
)

# The first recorded sample is skipped when averaging.
print(f'train steps/sec: {np.mean(train_sps[1:])}')