In [1]:
import jax.numpy as jnp

def pose_cross_map(pose) -> jnp.ndarray:
    """Build the 6x6 block-diagonal matrix diag(R, R) from a pose's 3x3 block.

    Only ``pose[0:3, 0:3]`` is read (presumably the rotation part of a 4x4
    homogeneous transform -- TODO confirm); everything else in ``pose`` is
    ignored.

    Args:
        pose: Array whose upper-left 3x3 block is used as R.

    Returns:
        jnp.ndarray: (6, 6) floating-point array with R in both diagonal 3x3
        blocks and zeros in the off-diagonal blocks.
    """
    rot = pose[:3, :3]
    # Float zero blocks promote an integer R to a float result as well.
    off_diag = jnp.zeros((3, 3))
    return jnp.block([[rot, off_diag], [off_diag, rot]])

In [2]:
# A 6x6 matrix of ones, used as a stand-in pose for pose_cross_map.
mat = jnp.ones(shape=(6, 6))
mat

Array([[1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1.]], dtype=float32)

In [3]:
pose_cross_map(pose=mat)

Array([[1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)

In [4]:
def momenta_cross_map(pose, mass_matrix, jacobian, joint_speeds) -> jnp.ndarray:
    """Build the 6x6 cross-product matrix of the task-space momenta.

    The momenta are computed as ``m = M_bar @ (jacobian @ qdot)``, where
    ``M_bar = pinv(jacobian).T @ mass_matrix @ pinv(jacobian)``. With ``m``
    split into its first three entries ``a = m[0:3]`` and last three
    ``b = m[3:6]``, the result is laid out as::

        [[ 0,      hat(a) ],
         [ hat(a), hat(b) ]]

    This is the same block layout as the notebook's scratch construction of
    ``m_cross``. NOTE(review): this matches the se(3) coadjoint structure only
    if the first three entries are the linear part -- confirm the ordering.

    Args:
        pose: Currently unused; kept so existing callers keep working.
            TODO(review): confirm whether the momenta should be expressed in
            the frame given by ``pose``.
        mass_matrix: (n, n) joint-space mass matrix (n = number of joints).
        jacobian: (6, n) Jacobian mapping joint speeds to a 6-vector twist.
        joint_speeds: n joint speeds, shaped (n,), (n, 1) or (1, n).

    Returns:
        jnp.ndarray: (6, 6) matrix laid out as above.
    """
    # Normalise to a column vector however the speeds were passed; the old
    # "transpose unless (6, 1)" check only worked for exactly six joints.
    qdot = jnp.reshape(joint_speeds, (-1, 1))
    j_inv = jnp.linalg.pinv(jacobian)
    # Mass matrix projected through the Jacobian pseudo-inverse: (6, 6).
    m_bar = j_inv.T @ mass_matrix @ j_inv
    twist = jacobian @ qdot
    momenta = (m_bar @ twist).reshape(-1)

    # hat_map expects (1, 3) row vectors.
    head_hat = hat_map(momenta[0:3].reshape((1, 3)))
    tail_hat = hat_map(momenta[3:6].reshape((1, 3)))

    momenta_cross = jnp.zeros((6, 6))
    momenta_cross = momenta_cross.at[0:3, 3:6].set(head_hat)
    momenta_cross = momenta_cross.at[3:6, 0:3].set(head_hat)
    momenta_cross = momenta_cross.at[3:6, 3:6].set(tail_hat)
    return momenta_cross

def hat_map(x: jnp.ndarray) -> jnp.ndarray:
    """
    Computes the hat (skew-symmetric) map from R^3 to so(3).

    For x = (x0, x1, x2) the result is the 3x3 matrix x^ satisfying
    x^ @ y == cross(x, y) for every y in R^3.

    Args:
        x (jnp.ndarray): Three-element input vector. A (1, 3) row vector is
            the original calling convention; (3,) and (3, 1) are also accepted.

    Returns:
        jnp.ndarray: 3x3 skew-symmetric matrix with the same dtype as ``x``.

    Raises:
        ValueError: If ``x`` is not shaped (1, 3), (3,) or (3, 1).
    """
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if jnp.shape(x) not in ((1, 3), (3,), (3, 1)):
        raise ValueError(
            f"hat_map expects a 3-vector shaped (1, 3), (3,) or (3, 1); got {jnp.shape(x)}"
        )
    x0, x1, x2 = jnp.ravel(jnp.asarray(x))
    # Typed zero so the result keeps the input dtype (e.g. int32 stays int32).
    zero = jnp.zeros_like(x0)
    return jnp.array([
        [zero, -x2, x1],
        [x2, zero, -x0],
        [-x1, x0, zero],
    ])

In [5]:
# `vec` is a (1, 6) row vector, so slicing the leading axis keeps the whole
# row; index the row first (vec[0, :3]) to get just the first three entries.
vec = jnp.arange(1, 7).reshape((1, 6))
vec[:3]

Array([[1, 2, 3, 4, 5, 6]], dtype=int32)

In [6]:
hat_map(vec[:, 3:6])

Array([[ 0, -6,  5],
       [ 6,  0, -4],
       [-5,  4,  0]], dtype=int32)

In [7]:
# Empty 6x6 canvas for the momenta cross-map experiment.
m_cross = jnp.zeros(shape=(6, 6))
m_cross

Array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]], dtype=float32)

In [8]:
# Layout: [[0, hat(vec[0:3])], [hat(vec[0:3]), hat(vec[3:6])]]; each .at[...]
# returns a new array, so the updates are chained rather than mutated in place.
m_cross = (
    m_cross
    .at[0:3, 3:6].set(hat_map(vec[:, 0:3]))
    .at[3:6, 0:3].set(hat_map(vec[:, 0:3]))
    .at[3:6, 3:6].set(hat_map(vec[:, 3:6]))
)
m_cross

Array([[ 0.,  0.,  0.,  0., -3.,  2.],
       [ 0.,  0.,  0.,  3.,  0., -1.],
       [ 0.,  0.,  0., -2.,  1.,  0.],
       [ 0., -3.,  2.,  0., -6.,  5.],
       [ 3.,  0., -1.,  6.,  0., -4.],
       [-2.,  1.,  0., -5.,  4.,  0.]], dtype=float32)

In [9]:
jnp.shape(m_cross)

(6, 6)

In [10]:
# Four identical rows [1, 2, 3, 4]; R is the upper-left 3x3 block and p the
# first three entries of the last column, kept as a (3, 1) column.
arr = jnp.tile(jnp.arange(1, 5), (4, 1))
R = arr[:3, :3]
p = arr[:3, 3:]

In [11]:
(R, jnp.shape(p))

(Array([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]], dtype=int32),
 (3, 1))

In [12]:
# (6, 1) column vector 1..6; show its last three entries.
error = jnp.arange(1, 7).reshape((6, 1))
error[3:]

Array([[4],
       [5],
       [6]], dtype=int32)

In [13]:
error * 0.5

Array([[0.5],
       [1. ],
       [1.5],
       [2. ],
       [2.5],
       [3. ]], dtype=float32, weak_type=True)

In [14]:
# Returns a new array with element [0, 0] replaced; `error` itself is unchanged.
error.at[0, 0].set(10)

Array([[10],
       [ 2],
       [ 3],
       [ 4],
       [ 5],
       [ 6]], dtype=int32)

In [18]:
# NOTE: this rebinds `vec` (earlier a (1, 6) integer row vector) to a (6, 1)
# float column; cells that relied on the old `vec` change if re-run after this.
vec = jnp.ones(shape=(6, 1))
vec

Array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)

In [19]:
# `vec` has a single column here, so slicing columns 0:3 keeps all six rows.
jnp.shape(vec[:, :3])

(6, 1)

In [21]:
# Count down from vec.shape[0] - 2 to 0 (prints 4..0 for a 6-row `vec`).
for i in reversed(range(vec.shape[0] - 1)):
    print(i)

4
3
2
1
0
