In [6]:
#@title Colab setup and imports

from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt
import numpy as np
import brax
from IPython.display import HTML, Image 

from brax import envs
from brax import jumpy as jp
from brax.envs import to_torch
from brax.io import html
from brax.io import image
from brax.io import mesh
import jax
from jax import numpy as jnp

## Brax Config

Here's a brax config that defines a bouncy ball:

In [31]:
#@title A bouncy ball scene
bouncy_ball = brax.Config(dt=0.05, substeps=20, dynamics_mode='pbd')

# ground is a frozen (immovable) infinite plane
ground = bouncy_ball.bodies.add(name='ground')
ground.frozen.all = True
plane = ground.colliders.add().plane
plane.SetInParent()  # for setting an empty oneof

# ball weighs 1kg, has equal rotational inertia along all axes, is 1m long, and
# has an initial rotation of identity (w=1,x=0,y=0,z=0) quaternion
ball = bouncy_ball.bodies.add(name='ball', mass=1)
cap = ball.colliders.add().mesh
cap.name = "Cylinder"
cap.scale = 0.05
bouncy_ball.mesh_geometries.add(name="Cylinder", path="../brax/tests/testdata/cylinder.stl")

# gravity is -9.8 m/s^2 in z dimension
bouncy_ball.gravity.z = -9.8


We visualize this system config like so:

## Brax State

$\text{QP}$, brax's dynamic state, is a structure with the following fields:

In [37]:
qp = brax.QP(
    # position of each body in 3d (z is up, right-hand coordinates)
    pos = np.array([[0., 0., 0.],       # ground
                    [0., 0., 3.]]),     # ball is 3m up in the air
    # velocity of each body in 3d
    vel = np.array([[0., 0., 0.],       # ground
                    [0., 0., 0.]]),     # ball
    # rotation about center of body, as a quaternion (w, x, y, z)
    rot = np.array([[1., 0., 0., 0.],   # ground
                    [1., 0., 0., 0.]]), # ball
    # angular velocity about center of body in 3d
    ang = np.array([[0., 0., 0.],       # ground
                    [0., 0., 0.]])      # ball

)

In [39]:
sys = brax.System(bouncy_ball)
rollout = []
new_qp = qp
for i in range(25):
  print("Step {}".format(i))
  new_qp, _ = sys.step(new_qp, [])
  rollout.append(new_qp)

HTML(html.render(sys, rollout))

Step 0
Step 1
Step 2
Step 3
Step 4
Step 5
Step 6
Step 7
Step 8
Step 9
Step 10
Step 11
Step 12
Step 13
Step 14
Step 15
Step 16
Step 17
Step 18
Step 19
Step 20
Step 21
Step 22
Step 23
Step 24
