In [None]:
#@title Install Brax and some helper modules
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a TPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'TPU'** in the dropdown.

from datetime import datetime
import functools
import os
from typing import Callable

from IPython.display import HTML, clear_output

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax import jumpy as jp
from brax.io import html
from brax.io import model
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac
from brax.training.agents.apg import train as apg

# Only on Colab TPU runtimes (detected through the COLAB_TPU_ADDR variable):
# run JAX's Colab TPU setup helper.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

First, let's pick an environment for the agent and preview its initial state:

In [None]:
#@title Preview a Brax environment { run: "auto" }

# Build the chosen environment and reset it with a fixed seed so the preview is reproducible.
env_name = "ant"
env = envs.get_environment(env_name=env_name)
state = env.reset(rng=jp.random_prngkey(seed=0))

# Render a single frame: the `qp` of the reset state.
HTML(html.render(env.sys, [state.qp]))

# Define the planning policy

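This method has no separate training phase: all "learning" happens inside the policy. At every environment step, the policy refines a sequence of future actions by running a few steps of (momentum) gradient ascent on the mean reward of a simulated rollout, obtained by differentiating through `env.step`.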

In [None]:
from brax import envs
from brax.training import types

from typing import Tuple
import functools


def make_policy(env) -> Callable[..., jnp.ndarray]:
    """Builds a gradient-based planning policy for ``env``.

    The returned ``policy`` improves a sequence of future actions by running
    ``epochs`` steps of momentum gradient ascent on the mean reward of a
    simulated rollout that starts from the given state.

    Args:
      env: environment whose ``step(state, action)`` returns a new state with a
        ``reward`` attribute; ``env.step`` is differentiated with ``jax.grad``.

    Returns:
      ``policy(state, actions_arr, epochs, m_fact, lr)``, which returns the
      refined action sequence (same shape as ``actions_arr``).
    """
    step = jax.jit(env.step)

    def run_ahead_carry(carry, action):
        # One simulated step: carry the new state forward and emit its reward.
        next_state = step(carry, action)
        return next_state, next_state.reward

    def calc_reward(state, actions):
        # Mean reward of the rollout that applies `actions` row by row from `state`.
        _, rewards = jax.lax.scan(run_ahead_carry, state, actions)
        return jnp.mean(rewards)

    # Gradient of the rollout reward with respect to the action sequence (argument 1).
    run_ahead_grad = jax.grad(jax.jit(calc_reward), argnums=1)

    def policy(state: envs.State, actions_arr, epochs, m_fact, lr) -> jnp.ndarray:
        """Returns ``actions_arr`` after ``epochs`` momentum gradient-ascent updates.

        Args:
          state: state every simulated rollout starts from.
          actions_arr: initial action sequence, one row per future step.
          epochs: number of gradient updates. Must be a Python int (mark it
            static under ``jax.jit``) because it sets the ``lax.scan`` length.
          m_fact: weight of the previous momentum in the moving average of
            gradients; 0 uses the raw gradient.
          lr: gradient-ascent step size.
        """

        def learn(carry, _):
            actions, momentum = carry
            actions_grad = run_ahead_grad(state, actions)
            # Exponential moving average of the gradient.
            momentum = m_fact * momentum + (1 - m_fact) * actions_grad
            # `+`: ascend, since we maximise reward.
            return (actions + momentum * lr, momentum), None

        init = (actions_arr, jnp.zeros_like(actions_arr))
        # No per-iteration inputs are needed, so scan over `length` instead of a dummy range.
        (actions_arr, _), _ = jax.lax.scan(learn, init, None, length=epochs)
        return actions_arr

    return policy

In [None]:
env = envs.create('ant')
step = jax.jit(env.step)

# Planner built from the same env. Argument 2 (`epochs`) is static because it
# fixes the lax.scan length inside the policy.
policy = make_policy(env)
jit_policy = jax.jit(policy, static_argnums=(2,))

def train(state, initial_actions, num_steps, m_fact, lr, epochs=10,
          policy_fn=None, step_fn=None):
    """Runs receding-horizon planning for ``num_steps`` environment steps.

    At every step the policy refines the current action plan and the first
    planned action is executed. The rest of the plan, shifted by one step and
    zero-padded at the end, warm-starts the next step.

    Args:
      state: initial environment state.
      initial_actions: initial action plan, one row per planned future step.
      num_steps: number of environment steps to execute.
      m_fact: momentum factor forwarded to the policy.
      lr: step size forwarded to the policy.
      epochs: gradient updates per step (a Python int; static under jit).
      policy_fn: planner ``(state, actions, epochs, m_fact, lr) -> actions``;
        defaults to the global ``jit_policy``.
      step_fn: ``(state, action) -> state``; defaults to the global ``step``.

    Returns:
      ``(state_sequence, total_reward)``: the states seen *before* each
      executed action (the final state is not included) and the summed reward.
    """
    if policy_fn is None:
        policy_fn = jit_policy
    if step_fn is None:
        step_fn = step

    state_sequence = []
    actions = initial_actions
    total_reward = 0

    for _ in range(num_steps):
        planned = policy_fn(state, actions, epochs, m_fact, lr)

        state_sequence.append(state)
        state = step_fn(state, planned[0])
        total_reward += state.reward

        # Warm-start the next step with the remaining plan: drop the executed
        # action and zero-pad the end so the horizon length stays constant.
        actions = jnp.pad(planned[1:], ((0, 1), (0, 0)), 'constant', constant_values=0)

    return state_sequence, total_reward

# Random initial action plan used to seed the planner.

def random_action(key, shape):
    """Builds a ``(rows, action_dim)`` plan whose rows all repeat one random action.

    The single action is drawn (approximately) uniformly from [-0.1, 0.1).
    """
    rows, action_dim = shape
    sample = jax.random.uniform(key, (1, action_dim))
    centered = (sample - 0.5) / 5
    # 'symmetric' padding of a one-row array just copies that row downwards.
    return jnp.pad(centered, ((0, rows - 1), (0, 0)), 'symmetric')

In [None]:
# Planning experiment. Reuses `env`, `jit_policy`, `step`, `train` and
# `random_action` from the cells above instead of rebuilding (and re-jitting) them.

lookforward = 20  # planning horizon: number of future actions optimised at each step
num_steps = 60    # environment steps to execute
lr = 2            # gradient-ascent step size
m_fact = 0        # momentum factor; 0 means plain gradient ascent
seed = 5

key = jax.random.PRNGKey(seed)
# Independent keys for the initial state and the initial plan; reusing one key
# for both would correlate them.
reset_key, action_key = jax.random.split(key)
actions = random_action(action_key, (lookforward, env.action_size))
rollout2, reward = train(env.reset(reset_key), actions, num_steps, m_fact, lr)

print(f"Reward: {reward}")

HTML(html.render(env.sys, [s.qp for s in rollout2]))

Reward: 69.41455841064453


In [None]:
# Render the rollout once and keep the raw HTML string in `render` for the next cell.
render = html.render(env.sys, [frame.qp for frame in rollout2])
HTML(render)

In [None]:
# Optional: print the raw HTML string behind the visualization above; safe to skip.
print(render)