In [1]:
from hct.envs.low_level_env import LowLevelEnv
from hct.envs.ant import Ant
from hct.envs.observer import Observer
from hct.envs.goal import GoalConstructor

import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

import brax

from brax.envs.humanoid import Humanoid
from brax.envs.humanoidstandup import HumanoidStandup
import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from hct import train as ppo
from brax.training.agents.sac import train as sac
from brax.base import Transform, Motion
from brax.math import euler_to_quat
from absl import logging

import getpass
import socket

from brax.kinematics import world_to_joint, inverse, forward




In [2]:
env = Ant()

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.random.uniform

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(1000):
  rollout.append(state.pipeline_state)
  act_rng, rng = jax.random.split(rng)
  act = jit_inference_fn(act_rng, shape=(8,), minval = -1,maxval=1)
  state = jit_env_step(state, act)

HTML(html.render(env.sys.replace(dt=env.dt), rollout))


In [None]:
kwargs = {'debug': True}
env = Ant(**kwargs)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.random.uniform

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(1000):
  rollout.append(state.pipeline_state)
  act_rng, rng = jax.random.split(rng)
  act = jit_inference_fn(act_rng, shape=(8,), minval = -1,maxval=1)
  state = jit_env_step(state, act)

HTML(html.render(env.sys.replace(dt=env.dt), rollout))

In [3]:
print(html.render(env.sys.replace(dt=env.dt), rollout))


<!DOCTYPE html>
<html>

  <head>
    <title>Brax visualizer</title>
    <link rel="shortcut icon" type="image/x-icon" href="/favicon.ico">
  </head>

  <style>
    body {
      margin: 0;
      padding: 0;
    }

    #brax-viewer {
      height: 480px;
      margin: 0;
      padding: 0;
    }
  </style>
  <script async src="https://unpkg.com/es-module-shims@1.6.3/dist/es-module-shims.js"></script>

  <script type="importmap">
    {
      "imports": {
        "three": "https://unpkg.com/three@0.150.1/build/three.module.js",
        "three/addons/": "https://unpkg.com/three@0.150.1/examples/jsm/",
        "lilgui": "https://cdn.jsdelivr.net/npm/lil-gui@0.18.0/+esm",
        "viewer": "https://cdn.jsdelivr.net/gh/google/brax@v0.9.1/brax/visualizer/js/viewer.js"
      }
    }
  </script>

  <script type="application/javascript">
  var system = {"dt": 0.04999999701976776, "gravity": [0.0, 0.0, -9.8100004196167], "viscosity": 0.0, "density": 0.0, "link": {"transform": {"pos": [[0.0, 0.0, 0.0

In [4]:
s = state.pipeline_state
from brax.geometry import contact


f = lambda y: contact(env.sys, y)
x = [f(s.x) for s in rollout]
print(x)


[Contact(pos=Array([[ 0.        ,  0.        ,  0.15      ],
       [ 0.4       ,  0.4       ,  0.23500001],
       [ 0.61612093,  0.61612093, -0.00300393],
       [-0.4       ,  0.4       ,  0.23500001],
       [-0.61612093,  0.61612093, -0.00300393],
       [-0.4       , -0.4       ,  0.23500001],
       [-0.61612093, -0.61612093, -0.00300393],
       [ 0.4       , -0.4       ,  0.23500001],
       [ 0.61612093, -0.61612093, -0.00300393]], dtype=float32), normal=Array([[0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.]], dtype=float32), penetration=Array([-0.3       , -0.47000003,  0.00600787, -0.47000003,  0.00600787,
       -0.47000003,  0.00600787, -0.47000003,  0.00600787], dtype=float32), friction=Array([1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), elasticity=Array([0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), solver_params=Array([[2.0

In [5]:
for i in x:
    print(i.pos)

[[ 0.          0.          0.15      ]
 [ 0.4         0.4         0.23500001]
 [ 0.61612093  0.61612093 -0.00300393]
 [-0.4         0.4         0.23500001]
 [-0.61612093  0.61612093 -0.00300393]
 [-0.4        -0.4         0.23500001]
 [-0.61612093 -0.61612093 -0.00300393]
 [ 0.4        -0.4         0.23500001]
 [ 0.61612093 -0.61612093 -0.00300393]]
[[-1.2581423e-03 -6.7540961e-03  1.6022289e-01]
 [ 3.5754704e-01  4.2954671e-01  2.4016196e-01]
 [ 5.2488565e-01  6.5746725e-01 -4.8125461e-03]
 [-3.7900928e-01  4.1073301e-01  2.4504125e-01]
 [-5.3413022e-01  6.0455954e-01 -9.1087595e-03]
 [-4.0272155e-01 -4.0405205e-01  2.5031447e-01]
 [-5.9406185e-01 -5.7505482e-01 -1.7504171e-03]
 [ 3.9768755e-01 -4.0686163e-01  2.4515831e-01]
 [ 5.7906222e-01 -6.2058043e-01 -5.1809102e-04]]
[[ 1.2500003e-02 -1.6772561e-02  1.8308216e-01]
 [ 3.1810039e-01  4.5786375e-01  2.6954532e-01]
 [ 3.7814853e-01  6.1616683e-01 -3.3164769e-04]
 [-3.6316085e-01  3.9156890e-01  2.8072065e-01]
 [-4.3259227e-01  4.881

In [39]:
arr = jp.array([2,4,6,8])
for i in range(10): 
    rng, rng1, rng2 = jax.random.split(rng, 3)
    n = jax.random.choice(rng2, jp.array(range(0,len(arr)+1)))
    r = jax.random.choice(rng1, arr, shape=(n,), replace = False)
    r = jp.sort(r)
    print(r)

print(arr[r])

[]
[2]
[4 6 8]
[2 4 6 8]
[2 4 6]
[2 4 6 8]
[4 8]
[]
[4 6 8]
[6 8]
[8 8]


In [None]:
import numpy as np
from brax.base import State
env = Ant()
goalconstructor = GoalConstructor(env.sys, configs = {})

q = env.sys.init_q 
qd = jp.zeros(env.sys.qd_size())
rng = jax.random.PRNGKey(seed=np.random.randint(0,1000))
@jax.jit
def f(rng):
    pipeline_state = env.pipeline_init(q, qd)
    
    # Sample and set goal
    goal = goalconstructor.sample_goal(rng, pipeline_state)

    """print(goal.q)
    print(goalconstructor.dof_limits)"""

    state = State(
        q = goal.q,
        qd = goal.qd,
        x = goal.x,
        xd = goal.xd,
        contact=None
    )
    return state

rollout = []
for r in range(10):
    rng, rng1 = jax.random.split(rng)
    rollout.append(f(rng1))

display(HTML(html.render(env.sys.replace(dt=env.dt), rollout)))

#self.goalconstructor.sample_goal()

In [None]:
display(HTML(html.render(env.sys.replace(dt=env.dt), rollout)))
