In [1]:
import jax
from jax import numpy as jnp
import mediapy as media
import mujoco
from mujoco import mjx
import numpy as np

from lib import calculate_fk, calculate_ik_step, calculate_ik

In [2]:
# MJCF for a two-joint arm:
# - a `shoulder` hinge about z and an `elbow` hinge about y;
# - two 0.2-long capsule links;
# - an `end_effector` site at the tip, carrying a touch sensor;
# - a static (joint-less) `target` sphere at (0.3, 0.1, 0.2).
# Colour components in `rgba` must lie in [0, 1].
# NOTE(review): both actuators are `motor` (torque) actuators, so `ctrl` is a torque
# command in [-1, 1], not a joint-angle setpoint. Confirm this is intended.
xml = """
<mujoco model="two_joint_arm">
    <compiler angle="degree" inertiafromgeom="true"/>
    <option timestep="0.01" gravity="0 0 -9.81"/>
    <worldbody>
        <light diffuse=".5 .5 .5" pos="0 0 3" dir="0 0 -1"/>
        <geom name="floor" pos="0 0 -0.1" size="1 1 0.1" type="plane" rgba="0.8 0.9 0.8 1"/>
        <body name="base" pos="0 0 0">
            <geom name="base" type="cylinder" size="0.05 0.02" rgba="0.2 0.2 0.2 1"/>
            <body name="upper_arm" pos="0 0 0.02">
                <joint name="shoulder" type="hinge" axis="0 0 1" range="-90 90"/>
                <geom name="upper_arm" type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" rgba="0.7 0.7 0 1"/>
                <body name="lower_arm" pos="0.2 0 0">
                    <joint name="elbow" type="hinge" axis="0 1 0" range="-90 90"/>
                    <geom name="lower_arm" type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" rgba="0 0.7 0.7 1"/>
                    <site name="end_effector" pos="0.2 0 0" size="0.01" rgba="1 0 0 1"/>
                </body>
            </body>
        </body>
        <body name="target" pos="0.3 0.1 0.2">
            <geom name="target" type="sphere" size="0.02" rgba="1 0 0 0.5"/>
        </body>
    </worldbody>
    <actuator>
        <motor joint="shoulder" ctrlrange="-1 1" ctrllimited="true"/>
        <motor joint="elbow" ctrlrange="-1 1" ctrllimited="true"/>
    </actuator>
    <sensor>
        <touch name="touch_sensor" site="end_effector"/>
    </sensor>
</mujoco>
"""

In [3]:
# --- Native MuJoCo: NumPy-backed model/data plus an offscreen renderer for frames ---
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

# --- MJX: JAX-backed copies of the same model and data ---
mjx_model, mjx_data = mjx.put_model(mj_model), mjx.put_data(mj_model, mj_data)

# Same joint positions in both worlds: a NumPy array on the host vs. a JAX array,
# together with the device(s) that JAX array lives on.
print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

[0. 0.] <class 'numpy.ndarray'>
[0. 0.] <class 'jaxlib.xla_extension.ArrayImpl'> {CpuDevice(id=0)}


In [4]:
def is_in_range(x, lower, upper):
    """Elementwise test of ``lower <= x <= upper`` (both bounds inclusive).

    Args:
        x: Scalar or array of values to test.
        lower: Inclusive lower bound.
        upper: Inclusive upper bound.

    Returns:
        Boolean JAX array with the broadcast shape of the inputs.
    """
    return jnp.logical_and(x >= lower, x <= upper)

def touch_target(z_pos, target_z, lower=0.185, upper=0.188):
    """Return a 2-element signal indicating whether ``z_pos`` is inside a height band.

    Args:
        z_pos: Scalar value to test (presumably the end-effector z height; TODO confirm).
        target_z: Currently unused; kept so existing callers keep working.
            NOTE(review): the band is set by ``lower``/``upper``, not derived from this value.
        lower: Inclusive lower bound of the band (defaults to the original 0.185).
        upper: Inclusive upper bound of the band (defaults to the original 0.188).

    Returns:
        ``jnp.array([0, -1])`` when ``lower <= z_pos <= upper``, else ``jnp.array([1, 0])``.
    """
    inside = is_in_range(jnp.asarray(z_pos), lower, upper)
    # Use jnp.where rather than a Python `if`. Under jax.jit / jax.vmap, `inside` is an
    # abstract tracer, and converting a tracer to a Python bool raises an error.
    return jnp.where(inside, jnp.array([0, -1]), jnp.array([1, 0]))

def cost(target_pos: jax.Array, current_pos: jax.Array) -> jax.Array:
    """Elementwise squared error between ``target_pos`` and ``current_pos``.

    The result is NOT reduced to a scalar: it has the broadcast shape of the inputs.
    Call ``.sum()`` (or ``.mean()``) on it when a scalar objective is needed, e.g. as
    the output of a function passed to ``jax.grad``.

    Args:
        target_pos: Desired values.
        current_pos: Current values, broadcast-compatible with ``target_pos``.

    Returns:
        ``(target_pos - current_pos) ** 2``, computed elementwise.
    """
    return jnp.square(target_pos - current_pos)

In [5]:
def radians_to_normalized(angle, angle_min, angle_max):
    """Linearly rescale ``angle`` so that ``angle_min`` maps to -1 and ``angle_max`` to 1.

    Args:
        angle: Value(s) to rescale (scalar or array).
        angle_min: Value that maps to -1.
        angle_max: Value that maps to 1.

    Returns:
        ``2 * (angle - angle_min) / (angle_max - angle_min) - 1``. No clipping is done,
        so inputs outside ``[angle_min, angle_max]`` map outside ``[-1, 1]``.
    """
    offset = angle - angle_min
    span = angle_max - angle_min
    return 2 * offset / span - 1
  

In [6]:
# Enable joint visualization in the rendered frames.
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

mujoco.mj_resetData(mj_model, mj_data)

# Joint limits from the XML (range="-90 90" degrees), expressed in radians.
# NOTE(review): these are unused below. `normalized_ik` is clipped to [-1, 1] rather than
# mapped from [min_angle, max_angle] (e.g. with radians_to_normalized). Confirm which is intended.
min_angle = -jnp.pi / 2
max_angle = jnp.pi / 2

duration = 10  # (seconds)
framerate = 60  # (Hz)
r = 0.2  # link length: both arm capsules in the XML are 0.2 long
desired_position = jnp.array([0.3, 0.1, 0.2])  # same as the `target` body pos in the XML

# Seed the RNG so the IK initial guess, and everything downstream, is reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
initial_q = rng.uniform(-1, 1, size=(2,))

ik = calculate_ik(initial_q=jnp.array(initial_q), length=r, desired_position=desired_position, alpha=1e-4)
# Clip each joint value into the actuators' ctrlrange [-1, 1].
# The bounds are passed positionally because newer JAX deprecates the `a`/`a_min`/`a_max`
# keywords. Assumes `ik` holds one value per joint (shape (2,)); TODO confirm against lib.
normalized_ik = jnp.clip(ik, -1, 1)
print(ik)
print(f'Desired joint config: {normalized_ik}')
print(f'Initial guess: {initial_q}')

# NOTE(review): the XML actuators are `motor` (torque) actuators. This therefore applies a
# constant torque per joint rather than driving the joints to `normalized_ik`. Confirm intent.
mj_data.ctrl[:] = np.asarray(normalized_ik)

frames = []
while mj_data.time < duration:
    mujoco.mj_step(mj_model, mj_data)
    # Render only once the simulation clock has passed the next frame's timestamp.
    if len(frames) < mj_data.time * framerate:
        renderer.update_scene(mj_data, scene_option=scene_option)
        frames.append(renderer.render())

print("Sensor data")
print(mj_data.sensor("touch_sensor").data)
# Simulate and display video.
media.show_video(frames, fps=framerate)

0.30122453
0.30137274
0.301521
0.3016693
0.30181772
0.3019662
0.30211475
0.30226338
0.30241203
0.3025608
0.30270964
0.30285856
0.30300757
0.30315658
0.30330572
0.30345497
0.30360422
0.30375355
0.303903
0.3040525
0.30420205
0.30435172
0.30450144
0.3046512
0.30480105
0.304951
0.305101
0.30525106
0.3054012
0.30555144
0.30570176
0.30585212
0.3060026
0.3061531
0.30630365
0.30645436
0.30660507
0.3067559
0.30690673
0.3070577
0.30720875
0.30735984
0.30751106
0.30766228
0.3078136
0.30796498
0.30811644
0.308268
0.30841962
0.3085713
0.3087231
0.30887488
0.3090268
0.3091788
0.30933085
0.30948296
0.3096352
0.30978748
0.30993983
0.31009224
0.31024474
0.31039733
0.31054994
0.31070265
0.31085545
0.31100833
0.3111613
0.31131428
0.31146735
0.31162056
0.31177378
0.3119271
0.31208047
0.31223392
0.31238747
0.3125411
0.31269476
0.31284848
0.31300232
0.31315625
0.31331018
0.31346425
0.3136184
0.3137726
0.31392688
0.31408116
0.3142356
0.31439012
0.31454468
0.31469932
0.31485403
0.31500882
0.31516364
0.3153186

0
This browser does not support the video tag.


In [8]:
# Small 2x3 integer array used to experiment with JAX's `.at[...]` indexing.
arr = jnp.asarray([[1, 2, 3], [1, 23, 4]])
arr

Array([[ 1,  2,  3],
       [ 1, 23,  4]], dtype=int32)

In [12]:
# Sum of row 0 of `arr`; plain `arr[0]` reads the same row as `arr.at[0].get()`.
arr[0].sum()

Array(6, dtype=int32)