# Training in Brax

Once an environment is created in Brax, we can quickly train a policy for it using Brax's built-in training algorithms. Let's try it out!

In [1]:
# Reload imported modules before each cell runs so edits to local source
# files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

<h2>Sample Gradient Calculations</h2>
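The following cell shows a simple example of how to compute a state gradient over a single simulation step in Brax: it trains a PPO policy, takes one step with it, and computes the Jacobian of the resulting observation with respect to the input state.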

In [2]:
import functools
import jax
from jax import numpy as jp

from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo



# Environment selection; the `@param` annotation renders as a dropdown in Colab.
env_name = 'ant'
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

# Environment instance handed to the PPO trainer.
env = envs.get_environment(env_name=env_name, backend=backend)

# NOTE(review): this reset result is not used before `state` is reassigned
# later in this cell.
state = env.reset(rng=jax.random.PRNGKey(seed=0))

# PPO hyperparameters, one per line for readability.
# NOTE(review): 50M timesteps is long-running; the saved output of this cell
# shows it was interrupted (KeyboardInterrupt). Consider a smaller value for demos.
train_fn = functools.partial(
    ppo.train,
    num_timesteps=50_000_000,
    num_evals=10,
    reward_scaling=10,
    episode_length=1000,
    normalize_observations=True,
    action_repeat=1,
    unroll_length=5,
    num_minibatches=32,
    num_updates_per_batch=4,
    discounting=0.97,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    num_envs=4096,
    batch_size=2048,
    seed=1,
)

make_inference_fn, params, _ = train_fn(environment=env)
inference_fn = make_inference_fn(params)

# Rebuild the environment with `envs.create` for the stepping/gradient demo.
env = envs.create(env_name=env_name, backend=backend)

def trimmed_state_step(state, action):
    """Step the global `env` once from `state` with `action`.

    Only the observation of the stepped state is returned, so that
    `jax.jacobian` of this function differentiates the observation alone.
    """
    return env.step(state, action).obs

def example_get_obs_grad(pipeline_grad, num_excluded_q=2):
    """Assemble d(obs)/d(q, qd) from a Jacobian w.r.t. the pipeline state.

    `pipeline_grad` is the `.pipeline_state` field of a `jax.jacobian`
    result. `jax.jacobian` appends the input's shape after the output's
    shape, so the differentiated coordinate of `pipeline_grad.q` and
    `pipeline_grad.qd` is on the LAST axis. Slicing and concatenating both
    happen on that axis, which also makes batched (higher-rank) Jacobians work.

    The first `num_excluded_q` position coordinates are dropped.
    NOTE(review): presumably these correspond to position entries the
    observation excludes (e.g. root x/y); confirm against the env's `_get_obs`.

    Args:
      pipeline_grad: object with `.q` and `.qd` array attributes.
      num_excluded_q: number of leading position coordinates to drop
        (default 2, the previously hard-coded value).

    Returns:
      `pipeline_grad.q[..., num_excluded_q:]` and `pipeline_grad.qd`
      concatenated along the last axis.
    """
    kept_q_grad = pipeline_grad.q[..., num_excluded_q:]
    return jp.concatenate([kept_q_grad, pipeline_grad.qd], axis=-1)

# Take one step with the trained policy, then differentiate that step.
rng = jax.random.PRNGKey(seed=1)
state = env.reset(rng=rng)

# Split off a key for the policy; `rng` is not used again in this cell.
act_rng, rng = jax.random.split(rng)
# NOTE(review): `_` here overwrites IPython's last-output variable.
act, _ = inference_fn(state.obs, act_rng)

# NOTE(review): despite its name, `new_q` holds the next *observation*
# (`trimmed_state_step` returns `.obs`), not generalized positions.
new_q = trimmed_state_step(state, act)

print(f"observation: {new_q.shape}")
print(new_q)

# Jacobian of the next observation w.r.t. the input State (same pytree
# structure as `state`); keep only the physics pipeline-state part.
pipeline_grad=jax.jacobian(trimmed_state_step, argnums=0)(state, act).pipeline_state
dobs_dstate = example_get_obs_grad(pipeline_grad)

print(f"pipeline state gradient: {dobs_dstate.shape}")
print(dobs_dstate)

KeyboardInterrupt: 

In [2]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output,display, IFrame

import brax


import flax
from brax import envs
import torch
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac
from torch2jax import j2t, t2j



class BraxHandler():
    """Wrap a Brax environment with jitted reset/step and a render log.

    `rollout` is a list of pipeline states, seeded with the initial state's,
    that `save_rollout` renders to HTML.
    """

    def __init__(self, env_name, backend,rng_seed=0):
        self.env_name = env_name
        self.backend = backend
        self.env = envs.create(env_name=env_name, backend=backend)
        self.jit_env_set_and_step = jax.jit(self.env.set_state_and_step)
        self.jit_env_reset = jax.jit(self.env.reset)
        # NOTE(review): stored but not read anywhere in this class.
        self.jit_env_rng = jax.random.PRNGKey(seed=1)
        reset_key = jax.random.PRNGKey(seed=rng_seed)
        self.init_state = self.jit_env_reset(rng=reset_key)
        self.rollout = [self.init_state.pipeline_state]

    def perform_step(self,input_state, act, render=False):
        """Advance `input_state` one step with action `act` and return the result.

        When `render` is true, the pipeline state of the *input* state (not
        the returned one) is appended to `rollout`.
        """
        next_state = self.jit_env_set_and_step(input_state, act)
        if render:
            self.rollout.append(input_state.pipeline_state)
        return next_state

    def get_rollout(self):
        """Return the recorded list of pipeline states (not a copy)."""
        return self.rollout

    def save_rollout(self, filename):
        """Write an HTML rendering of the recorded rollout to `filename`."""
        with open(filename, 'w') as out_file:
            render_sys = self.env.sys.tree_replace({'opt.timestep': self.env.dt})
            out_file.write(html.render(render_sys, self.rollout))


# Smoke-test the handler on the 'walker2d_mpc' environment.
env_name = 'walker2d_mpc'
backend = 'generalized'
brax_handler = BraxHandler(env_name, backend)
state = brax_handler.init_state


def step_obs(state, act):
    """Advance the global `brax_handler` one step and return only the observation."""
    stepped = brax_handler.perform_step(state, act)
    return stepped.obs


# Jitted Jacobians of the next observation w.r.t. the action (argnums=1)
# and the input state (argnums=0).
action_grad = jax.jit(jax.jacobian(step_obs, argnums=1))
state_grad = jax.jit(jax.jacobian(step_obs, argnums=0))

# for _ in range(1):
#     act = jp.ones(brax_handler.env.action_size, dtype=float)
#     # old_state=jp.copy(state)
#     state = brax_handler.perform_step(state, act,render=True)
#     print(state.pipeline_state.q.shape)
#     print(state.pipeline_state.q)
#     print(state.pipeline_state.qd.shape)
#     print(state.pipeline_state.qd)
#     print(state.obs.shape)
#     print(state.obs)
#     # # compute grads
#     # print(action_grad(state, act))
#     # print(state_grad(state, act))

# brax_handler.save_rollout(f"{env_name}_test.html")

In [2]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output,display, IFrame

import brax


import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sacaction_grad
from torch2jax import j2t, t2j
from mrf_swarm.sim.robot import DynamicsModel
import torch
from functools import singledispatch, update_wrapper
from brax.envs.base import PipelineEnv, State

class MultiDispatch:
    """Dispatch a call on the runtime types of its positional arguments.

    Implementations are registered for a tuple of types via `register`.
    On call, resolution is:
      1. an implementation registered for the exact argument types;
      2. otherwise the first registration (in registration order) whose
         types all match via `isinstance`, with the same arity;
      3. otherwise `default`.

    Step 2 matters for abstract/virtual types such as `jax.numpy.ndarray`
    (`jax.Array`): concrete JAX arrays are instances of an implementation
    class, so an exact `type(arg)` lookup alone never matches them.
    """

    def __init__(self, default):
        self.default = default
        self.dispatch = {}
        # Make the dispatcher carry `default`'s __name__/__doc__/etc.
        update_wrapper(self, default)

    def register(self, *types):
        """Return a decorator that registers `func` for argument types `types`."""
        def wrapper(func):
            self.dispatch[types] = func
            return func
        return wrapper

    def _resolve(self, args):
        """Choose the implementation for positional `args` (exact, isinstance, default)."""
        arg_types = tuple(type(arg) for arg in args)
        if arg_types in self.dispatch:
            return self.dispatch[arg_types]
        for registered_types, func in self.dispatch.items():
            if len(registered_types) == len(args) and all(
                    isinstance(arg, typ) for arg, typ in zip(args, registered_types)):
                return func
        return self.default

    def __call__(self, *args, **kwargs):
        # Only positional arguments participate in dispatch; keyword
        # arguments are forwarded unchanged to the chosen implementation.
        return self._resolve(args)(*args, **kwargs)

class BraxHandler():
    """Brax environment wrapper with jitted stepping, Jacobians and a render log.

    Attributes set in ``__init__``:
      env: environment from ``envs.create(env_name, backend)``.
      jit_env_set_and_step / jit_env_reset: jitted ``env.set_state_and_step`` / ``env.reset``.
      init_state: result of the jitted reset with key ``PRNGKey(rng_seed)``.
      rollout: recorded pipeline states, starting with ``init_state``'s.
      action_grad1 / state_grad1: jitted Jacobians of ``step_obs`` w.r.t. action / state.
      action_grad / state_grad: jitted Jacobians of ``full_grad`` (the whole next State).
      jit_generate_state_from_tensor: jitted ``env.generate_state_from_tensor``.
    """

    def __init__(self, env_name, backend,rng_seed=0):
        self.env_name = env_name
        self.backend = backend
        self.env = envs.create(env_name=env_name, backend=backend)
        self.jit_env_set_and_step = jax.jit(self.env.set_state_and_step)
        self.jit_env_reset = jax.jit(self.env.reset)
        # NOTE(review): not read anywhere in this class.
        self.jit_env_rng = jax.random.PRNGKey(seed=1)
        self.init_state = self.jit_env_reset(rng=jax.random.PRNGKey(seed=rng_seed))
        self.rollout = [self.init_state.pipeline_state]

        # Observation-only Jacobians.
        self.action_grad1 = jax.jit(jax.jacobian(self.step_obs, argnums=1))
        self.state_grad1 = jax.jit(jax.jacobian(self.step_obs, argnums=0))
        # Jacobians of the full next State (consumed by perform_step_w_grad).
        self.action_grad = jax.jit(jax.jacobian(self.full_grad, argnums=1))
        self.state_grad = jax.jit(jax.jacobian(self.full_grad, argnums=0))
        self.jit_generate_state_from_tensor = jax.jit(self.env.generate_state_from_tensor)

    def step_obs(self,state, act):
        """Step once from ``state`` with ``act`` and return only the next observation."""
        return self.perform_step(state, act).obs

    def full_grad(self,state, act):
        """Step once from ``state`` with ``act`` and return the whole next State.

        Despite the name this returns the stepped state itself; it is the
        function whose Jacobians ``action_grad``/``state_grad`` compute.
        """
        return self.perform_step(state, act)

    def perform_step(self,input_state, act, render=False):
        """Advance ``input_state`` one step; optionally record the *input* state."""
        next_state = self.jit_env_set_and_step(input_state, act)
        if render:
            self.add_to_rollout(input_state)
        return next_state

    def perform_step_w_grad(self,input_state, act, render=False):
        """Step once and also return observation/action gradients.

        Returns:
          ``(next_state, obs_grad, action_grad)`` where the two gradients are
          whatever ``env.get_obs_grad`` derives from the Jacobians of the
          full next State w.r.t. the input state and the action.
        """
        next_state = self.jit_env_set_and_step(input_state, act)

        state_jacobian = self.state_grad(input_state, act)
        act_jacobian = self.action_grad(input_state, act)
        obs_grad, action_grad = self.env.get_obs_grad(state_jacobian, act_jacobian)

        if render:
            self.add_to_rollout(input_state)
        return next_state, obs_grad, action_grad

    def get_rollout(self):
        """Return the recorded list of pipeline states (not a copy)."""
        return self.rollout

    def add_to_rollout(self, state):
        """Append ``state.pipeline_state`` to the render log."""
        self.rollout.append(state.pipeline_state)

    def save_rollout(self, filename):
        """Write an HTML rendering of the recorded rollout to ``filename``."""
        with open(filename, 'w') as out_file:
            render_sys = self.env.sys.tree_replace({'opt.timestep': self.env.dt})
            out_file.write(html.render(render_sys, self.rollout))

class BraxTorchStateMapper():
    """Pair a Brax state/action with torch conversions of the observation,
    action, and supplied gradients (via ``j2t``).

    NOTE(review): not instantiated anywhere in the visible code.
    """

    def __init__(self, state,action, state_grad, act_grad):
        # JAX-side originals.
        self.brax_state, self.brax_action = state, action

        # Torch-side conversions, in the same order as before.
        self.x = j2t(state.obs)
        self.u = j2t(action)
        self.act_grad = j2t(act_grad)
        self.state_grad = j2t(state_grad)

class BraxHelperModel(DynamicsModel):
    """Dynamics-model adapter that rolls a shared ``BraxHandler`` forward.

    ``rollout`` accepts either (JAX actions, Brax ``State``) or
    (torch actions, torch state tensor) and returns, per step, a torch tensor
    of ``[next observation, action]`` concatenated along dim 0.

    Args:
      brax_handler: handler providing ``perform_step``, ``perform_step_w_grad``
        and ``jit_generate_state_from_tensor``.
      horizon: number of steps each rollout takes.
      render: when true, each step's input state is recorded in the
        handler's rollout.
    """

    def __init__(self, brax_handler,horizon=1,render=False):
        self.brax_handler=brax_handler
        self.horizon=horizon
        self.render=render
        self.generate_state_from_tensor = self.brax_handler.jit_generate_state_from_tensor

        # Per-instance type dispatch for ``rollout``; unregistered argument
        # types fall back to the JAX path via ``default_rollout``.
        self.rollout = MultiDispatch(self.default_rollout)
        self.rollout.register(jp.ndarray,State)(self.rollout_jp)
        self.rollout.register(torch.Tensor,torch.Tensor)(self.rollout_torch)

    @staticmethod
    def _as_action_sequence(U):
        """Promote a single 1-D action vector to a one-row action sequence."""
        if U.ndim == 1:
            return U[jp.newaxis, ...]
        return U

    @staticmethod
    def _state_action_tensor(next_state, action):
        """Torch tensor of ``next_state.obs`` followed by ``action`` (dim 0)."""
        return torch.cat((j2t(next_state.obs), j2t(action)), dim=0)

    def default_rollout(self, *args):
        """Fallback for unregistered argument types: use the JAX rollout."""
        return self.rollout_jp(*args)

    def rollout_jp(self, U, x_0):
        """Roll ``self.horizon`` steps from Brax state ``x_0`` using actions ``U``.

        Args:
          U: actions indexed by step along axis 0 (1-D means a single step).
            NOTE(review): for JAX arrays, out-of-range indices are clamped,
            so if ``U`` has fewer rows than ``self.horizon`` the last action
            is silently reused.
          x_0: Brax ``State`` to start from.

        Returns:
          List of ``self.horizon`` torch tensors ``[next obs, action]``.
        """
        U = self._as_action_sequence(U)
        states = []
        state = x_0
        for i in range(self.horizon):
            next_state = self.brax_handler.perform_step(state, U[i], render=self.render)
            states.append(self._state_action_tensor(next_state, U[i]))
            state = next_state
        return states

    def rollout_torch(self, U, x_0):
        """Torch entry point: convert to JAX, rebuild a Brax state, then roll out.

        NOTE(review): the literals 9 and 18 are forwarded to the environment's
        ``generate_state_from_tensor``; their meaning is defined there.
        """
        state=self.generate_state_from_tensor(t2j(x_0),9,18)
        return self.rollout(t2j(U),state)

    def rollout_w_grad(self, U,x_0):
        """Like ``rollout_jp`` but also collects per-step gradients.

        Returns:
          ``(action_grads, obs_grads, states, full_states)``. The first two are
          torch conversions of the gradients returned by
          ``brax_handler.perform_step_w_grad``. ``states`` holds the
          ``[next obs, action]`` tensors, and ``full_states`` holds the raw
          next Brax states.
        """
        U = self._as_action_sequence(U)
        states=[]
        obs_grads=[]
        action_grads=[]
        full_states=[]
        state=x_0
        for i in range(self.horizon):
            next_state, obs_grad, action_grad = self.brax_handler.perform_step_w_grad(
                state, U[i], render=self.render)
            full_states.append(next_state)
            states.append(self._state_action_tensor(next_state, U[i]))
            obs_grads.append(j2t(obs_grad))
            action_grads.append(j2t(action_grad))
            state=next_state
        return action_grads,obs_grads, states,full_states

    def is_linear(self):
        """Brax dynamics are treated as nonlinear."""
        return False

class BraxModel(DynamicsModel):
    """Bundle of ``BraxHelperModel`` views over one shared ``BraxHandler``.

    Attributes:
      brax_handler: ``BraxHandler(env_name, backend, rng_seed)`` shared by all views.
      horizon: rollout length of ``horizon_steps_model``.
      one_step_model: horizon-1 model without rendering.
      horizon_steps_model: ``horizon``-step model without rendering.
      render_model: horizon-1 model with ``render=True``.
    """

    def __init__(self, env_name, backend,rng_seed=0,horizon=1):
        self.brax_handler=BraxHandler(env_name, backend,rng_seed)
        self.horizon=horizon
        print(f"brax model horizon: {self.horizon}")

        shared_handler = self.brax_handler
        self.one_step_model = BraxHelperModel(shared_handler, horizon=1)
        self.horizon_steps_model = BraxHelperModel(shared_handler, horizon=self.horizon)
        self.render_model = BraxHelperModel(shared_handler, horizon=1, render=True)
        print("horizon steps model horizon: ",self.horizon_steps_model.horizon)

    def is_linear(self):
        """Brax dynamics are treated as nonlinear."""
        return False


# # Test the class
# env_name = 'walker2d_mpc'
# backend = 'generalized'
# # brax_handler = BraxHandler(env_name, backend)
# brax_model = BraxModel(env_name, backend,horizon=20)
# state = brax_model.brax_handler.init_state
# # print(state.obs)

# for _ in range(1):
#     # Pick an action
#     act = jp.ones(brax_model.brax_handler.env.action_size, dtype=float)
#     # print(act.shape)
#     # Perform a step with gradients
#     action_grad,obs_grad,new_state, full_state = brax_model.render_model.rollout_w_grad(act,state)
#     print(full_state[0].pipeline_state.q.shape)
#     print(full_state[0].pipeline_state.qd.shape)
#     print(new_state[0].shape)
#     print(action_grad[0].shape)
#     print(obs_grad[0].shape)
#     # Assign new state (if continually running)
#     state = full_state[0]
    
# brax_model.brax_handler.save_rollout(f"{env_name}_test2.html")

In [3]:
import os
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
from mrf_swarm.envs import PointSwarm
from mrf_swarm.controllers.mpc import SteinMPC
from mrf_swarm.sim.robot import LinearPointRobotModel
from mrf_swarm.sim.map import DiffMap
from plotting import draw_belief_traj, draw_paritcles
from utils.point_swarm import make_costs, make_terminal_costs
from mrf_swarm.costs.base_costs import CompositeSumCost

from cascaded_cost_factors import DynamicsPairwiseFactor
from cascaded_cost_factors import CascadedLoopySVBP
from tqdm import tqdm
from outputs_to_gif import plot_as_gif
from utils.walker_cost import make_walking_costs,make_terminal_walking_costs
from cascaded_cost_factors import CascadedMPC, CascadedMPCBrax
import torch_bp.bp as bp
import torch_bp.distributions as dist

from torch_bp.graph import factors, Graph
from torch_bp.util.plotting import plot_dists, plot_graph, plot_particles
from torch_bp.inference.kernels import RBFMedianKernel
from mrf_swarm.factors.trajectory_factors import UnaryRobotTrajectoryFactor
from torch2jax import j2t, t2j


# NOTE(review): FIG_WIDTH, DT and SIM_TIME are not used in this cell.
FIG_WIDTH = 6
DT = 0.1
SIM_TIME = 20
HORIZON = 2
NUM_PARTICLES = 3
tensor_kwargs = {"device": 'cpu', "dtype": torch.float}

torch.random.manual_seed(0)

# --- Brax model and initial state ---
env_name = 'walker2d_mpc'
backend = 'generalized'
brax_model=BraxModel(env_name, backend,horizon=HORIZON)
brax_state = brax_model.brax_handler.init_state
torch_state= j2t(brax_state.obs)
steps = 10  # number of iterations of the main MPC loop below

# --- Costs ---
x_goal=10
cascading_costs = make_walking_costs(
    c_pos_x=0.1, c_vel_x=0.25, c_u=0.2, c_term_x=6., c_pos_z=1., c_vel_z=0.5,
    c_term_z=0., dim=2, horizon=1, goal_x=x_goal, goal_z=1.5,
    use_terminal_cost=False, tensor_kwargs=tensor_kwargs)
final_costs = make_terminal_walking_costs(
    c_term_x=6., c_term_z=0., dim=2, horizon=1, goal_x=x_goal, goal_z=1.5,
    tensor_kwargs=tensor_kwargs)
full_costs = make_walking_costs(
    c_pos_x=0.1, c_vel_x=0.25, c_u=0.2, c_term_x=6., c_pos_z=1., c_vel_z=0.5,
    c_term_z=0., dim=2, horizon=1, goal_x=x_goal, goal_z=1.5,
    use_terminal_cost=True, tensor_kwargs=tensor_kwargs)

# NOTE(review): the constant 2*2*20 is unexplained; document its origin.
gamma = 1. / np.sqrt(2*2*20)
rbf_kernel = RBFMedianKernel(gamma=gamma)

# TODO fix passing model rollout functions to cascaded mpc
mpc = CascadedMPCBrax(
    cascading_costs, final_costs,
    brax_model.one_step_model, brax_model.horizon_steps_model,
    rbf_kernel, brax_state,
    num_particles=NUM_PARTICLES, horizon=HORIZON, dim=6,
    init_cov=0.5, optim_params={"lr": 0.05},
    full_costs=full_costs, one_cost=cascading_costs, term_cost=final_costs,
    tensor_kwargs=tensor_kwargs)
total_cost=0

# Tolerances for the (currently commented-out) early-stop check.
POS_TOL = 0.5
VEL_TOL = 0.2

# --- Warm-up: one solver iteration per step, then shift with the initial observation ---
for i in tqdm(range(HORIZON)):
    mpc.solve(num_iter=1, normalize=True, precompute=False)
    # TODO: plotting?
    mpc.shift(torch_state)
particle_seed = mpc.get_particles()
mpc.sbp.reset(particle_seed)
combined_cost_eval = CompositeSumCost(costs=cascading_costs, sigma=1,
                                      tensor_kwargs=tensor_kwargs)

print("running mpc")
for i in tqdm(range(steps)):
    mpc.solve(num_iter=1, normalize=True, precompute=False)
    state_action=mpc.get_best_state_action()
    # NOTE(review): the 0:4 / 4: split looks inherited from a 4-D point-robot
    # state; confirm it matches this environment's observation size.
    state=state_action[0,0:4]
    action=state_action[0,4:]
    # `rollout` returns a Python list with one [next obs, action] tensor per step.
    next_state_action=brax_model.render_model.rollout(action[None, ...],state[None, ...])
    mpc.action_history.append(action)
    mpc.state_history.append(state)
    normalized_cost = combined_cost_eval(state_action[0])
    total_cost+=normalized_cost
    # TODO Do save to rollout
    # Index the list first, then slice the tensor: `list[0, 0:4]` raises TypeError.
    mpc.shift(next_state_action[0][0:4])

    # # if at terminal position and velocity is low, end early
    # if torch.norm(state[0:2] - x_goal) < POS_TOL and torch.norm(state[2:4]) < VEL_TOL:
    #     print("Reached goal at iteration ", i)
    #     break

# running_cost=combined_cost_evalrunning_cost
# print("running initialization")


# for i in range(steps):
#     # Pick an action
#     act = jp.ones(brax_handler.env.action_size, dtype=float)

#     # Perform a step with gradients
#     new_state,obs_grad,action_grad = brax_handler.perform_step_w_grad(state, act,render=True)
   
#     # Assign new state (if continually running)
#     state = new_state

brax model horizon: 2
horizon steps model horizon:  2
Q is: torch.Size([26, 26])
Q is: torch.Size([26, 26])
did first rollout
horizon: 1
torch.Size([1, 20, 1, 6])
(1,20,1,20)
horizon: 1
torch.Size([1, 20, 1, 6])
(1,20,1,20)
horizon: 1
torch.Size([1, 20, 1, 6])
(1,20,1,20)
Done Intiializing MPC class


  0%|          | 0/2 [00:00<?, ?it/s]

start of grad log likelyhood
traj cost grad: tensor([[[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
          -0.0000, -0.0000,  0.1737,  0.1238, -0.0000, -0.0000, -0.0000,
          -0.0000, -0.0000, -0.0000, -0.0000,  2.0006,  0.5109,  0.3052,
           0.3105,  0.1301,  0.1668, -0.0897, -0.0584]],

        [[-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
          -0.0000, -0.0000,  0.4299,  0.3114, -0.0000, -0.0000, -0.0000,
          -0.0000, -0.0000, -0.0000, -0.0000,  2.0017,  0.5146,  0.1432,
           0.5030,  0.0155,  0.3327,  0.0100,  0.0184]]])
traj cost grad shape: torch.Size([2, 1, 26])
log px: tensor([-10.4168, -10.7979])
log px shape: torch.Size([2])
grad_log_px: tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0673,  0.0323, -0.0267,  0.0768,
          0.0108, -0.0171],
     

: 