<a href="https://colab.research.google.com/github/jamesjbustos/init-build/blob/main/robot_rl/RL_Robot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1 style="text-align:center">RL Robot</h1>

---

<p style="text-align:center">
    <strong>Author:</strong> <a href="https://github.com/jamesjbustos">James Bustos</a>
</p>
<p style="text-align:center">
    <strong>Reference:</strong><br>
    <em>Google Brax</em>. GitHub repository. <a href="https://github.com/google/brax">https://github.com/google/brax</a>
</p>
<p style="text-align:center">
    <strong>Date & Topics:</strong><br>
    <time>November 14, 2024</time> | <span>Google Brax, MuJoCo, and RL</span>
    <br><br>  
</p>

## Prerequisites

In [None]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

# Only when the COLAB_TPU_ADDR environment variable is present: load JAX's
# Colab TPU helper module and run its setup routine.
if os.environ.get('COLAB_TPU_ADDR') is not None:
  import jax.tools.colab_tpu as colab_tpu
  colab_tpu.setup_tpu()

## Setting up Brax environment

In [None]:
from brax.generalized import pipeline
from brax.io import mjcf
# Scene description (MJCF): one box body on a free joint placed at pos "0 0 3",
# plus a plane geom, with the simulation timestep option set to 0.005.
BOX_SCENE_MJCF = """<mujoco>
         <option timestep="0.005"/>
         <worldbody>
           <body pos="0 0 3">
             <joint type="free"/>
             <geom size="1 1 1" type="box"/>
           </body>
           <geom size="40 40 40" type="plane"/>
         </worldbody>
       </mujoco>
  """
m = mjcf.loads(BOX_SCENE_MJCF)

# Wrap the pipeline's init/step functions with jax.jit before the loop so the
# same jitted callables are reused on every iteration.
jit_env_reset = jax.jit(pipeline.init)
jit_env_step = jax.jit(pipeline.step)

rng = jax.random.PRNGKey(seed=1)  # created here but not consumed in this cell

# Initialise the state from m.init_q together with a zero vector of length
# m.qd_size().
initial_qd = jp.zeros(m.qd_size())
state = jit_env_reset(m, m.init_q, initial_qd)

# Collect 500 states. Each step applies the same value, 10 * sin(i / 100), to
# every entry of an action vector of length m.act_size().
rollout = []
for i in range(500):
  rollout.append(state)
  amplitude = 10 * jp.sin(i / 100)
  act = amplitude * jp.ones(m.act_size())
  state = jit_env_step(m, state, act)

# Render the collected states to HTML and show the result as the cell output.
rendered_scene = html.render(m, rollout)
HTML(rendered_scene)


In [2]:
!git clone https://github.com/jamesjbustos/init-build.git
! cd init-build

Cloning into 'init-build'...
remote: Enumerating objects: 239, done.[K
remote: Counting objects: 100% (186/186), done.[K
remote: Compressing objects: 100% (149/149), done.[K
remote: Total 239 (delta 70), reused 137 (delta 28), pack-reused 53 (from 1)[K
Receiving objects: 100% (239/239), 110.70 MiB | 30.54 MiB/s, done.
Resolving deltas: 100% (82/82), done.
Updating files: 100% (39/39), done.


In [3]:
import os

In [3]:
# Move the kernel's working directory into the cloned project folder.
# The path is relative, so a second run of this cell would look for
# init-build/robot_rl *inside* robot_rl and raise FileNotFoundError; skip the
# chdir when the kernel is already there. A missing checkout still raises.
if not os.getcwd().endswith(os.path.join("init-build", "robot_rl")):
  os.chdir("init-build/robot_rl")