In [1]:
import torch

In [6]:
class term():
    joint_ids = ['FL_hip_joint', 'FR_hip_joint', 'RL_hip_joint', 'RR_hip_joint', 'FL_thigh_joint', 'FR_thigh_joint', 'RL_thigh_joint', 'RR_thigh_joint', 'FL_calf_joint', 'FR_calf_joint', 'RL_calf_joint', 'RR_calf_joint']#['RR_hip', 'RR_joint', 'RR_calf']
    num_envs = 64
    _num_legs = 4
    _prevision_horizon = 10

    def __init__(self):
        self._num_joints = len(self.joint_ids)
        self.f = torch.zeros(self.num_envs, self._num_legs)
        self.d = torch.zeros(self.num_envs, self._num_legs)
        self.p = torch.zeros(self.num_envs, self._num_legs, self._prevision_horizon)
        self.F = torch.zeros(self.num_envs, self._num_legs, self._prevision_horizon)
        self.z = [self.f, self.d, self.p, self.F]

        # create tensors for raw and processed actions
        self._raw_actions = torch.zeros(self.num_envs, self.action_dim2)
        self._processed_actions = torch.zeros_like(self.raw_actions)

    @property
    def action_dim(self) -> int:
        return self._num_joints
    
    @property
    def action_dim2(self) -> int:
        return self.f.shape[1:].numel() + self.d.shape[1:].numel() + self.p.shape[1:].numel() + self.F.shape[1:].numel()
    
    @property
    def action_dim3(self) -> int:
        return sum(variable.shape[1:].numel() for variable in self.z)
    
    @property
    def raw_actions(self) -> torch.Tensor:
        return self._raw_actions

    @property
    def processed_actions(self) -> torch.Tensor:
        return self._processed_actions

In [8]:
term1 = term()

In [10]:
term1.action_dim2

88

In [9]:
term1.action_dim3

88

In [14]:
term1.raw_actions.shape

torch.Size([64, 88])

In [77]:
term1.z[1]

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [50]:
term1.num_envs

4

In [68]:
a = torch.zeros(term1.num_envs, term1.action_dim,2)
print('tensor :',a)
print('shape :', a.shape[1:].numel())

print(a.flatten().shape)

tensor : tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]]])
shape : 6
torch.Size([24])


In [None]:
from __future__ import annotations

from omni.isaac.orbit.assets import AssetBase
from omni.isaac.orbit.assets.articulation import Articulation
from omni.isaac.orbit_tasks.locomotion.model_based.model_based_env_cfg import LocomotionModelBasedEnvCfg
from omni.isaac.orbit_tasks.locomotion.model_based.config.unitree_aliengo.aliengo_base_env_cfg import UnitreeAliengoBaseEnvCfg

In [4]:
import jax.numpy as jnp
import jax
seed = 42
key = jax.random.key(seed)

In [35]:
print(f'available devices: {torch.cuda.device_count()}')
print(f'current device: { torch.cuda.current_device()}')
torch.cuda.get_device_name(0)

available devices: 1
current device: 0


'NVIDIA GeForce RTX 3060'

In [40]:
output_torques = (torch.rand(term1.num_envs, term1._num_joints, device='cuda') * 80) - 40
print('shape : ',output_torques.shape)
print('device : ',output_torques.device)

shape :  torch.Size([64, 12])
device :  cuda:0


In [7]:
output_torques_jax = jax.random.normal(key=key, shape=output_torques.shape)

In [13]:
print('--- Torch ---')
print('Shape : ', output_torques.shape)
print('Type : ', output_torques.type())
print('Type : ', type(output_torques))

print('')
print('--- Jax ---')
print('Shape : ', output_torques_jax.shape)
print('Type : ', type(output_torques_jax))

--- Torch ---
Shape :  torch.Size([64, 12])
Type :  torch.FloatTensor
Type :  <class 'torch.Tensor'>

--- Jax ---
Shape :  (64, 12)
Type :  <class 'jaxlib.xla_extension.ArrayImpl'>


In [15]:
output_torques_jax.std()

Array(1.0063325, dtype=float32)

In [17]:
import jax
import jax.dlpack
import torch
import torch.utils.dlpack

def jax_to_torch(x):
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
def torch_to_jax(x):
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))

a = torch.tensor([1,2,3]).cuda()
a_jax = torch_to_jax(a)
print(a_jax)

[1 2 3]


In [18]:
a.device

device(type='cuda', index=0)

In [28]:
a_jax.devices()

{cuda(id=0)}

In [42]:
output_torques = (torch.rand(term1.num_envs, term1._num_joints, device='cuda') * 80) - 40
print('shape : ',output_torques.shape)
print('device : ',output_torques.device)
print('Type : ', type(output_torques))

shape :  torch.Size([64, 12])
device :  cuda:0
Type :  <class 'torch.Tensor'>


In [45]:
output_torques_jax = torch_to_jax(output_torques)
print('Shape : ', output_torques_jax.shape)
print('device : ',output_torques_jax.devices())
print('Type : ', type(output_torques_jax))

Shape :  (64, 12)
device :  {cuda(id=0)}
Type :  <class 'jaxlib.xla_extension.ArrayImpl'>


In [25]:
def alo() -> tuple[int, int, str]:
    a = 2
    b = 3
    c = 4
    return a, b, str(c)

def alo2():
    a = 2
    b = 3
    c = 4
    return a, b, str(c)

In [24]:
alo()

(2, 3, '4')

In [26]:
alo2()

(2, 3, '4')

In [27]:
print(type(alo()))
print(type(alo2()))

<class 'tuple'>
<class 'tuple'>


In [10]:
a = tuple[2,3,4]
alo()

d, f, e = alo()
type(alo())

tuple

In [31]:
import torch

a = torch.tensor([0, 1.21, 2])
b = torch.tensor([True, True, False])

shape = [2,3]
a = torch.rand(shape)
b = torch.empty(shape, dtype=torch.bool).bernoulli(0.5)


print(a.dtype)
print(b.dtype)
print(a)
print(b)

a*b

torch.float32
torch.bool
tensor([[0.5317, 0.4781, 0.3271],
        [0.3938, 0.3433, 0.9002]])
tensor([[ True, False,  True],
        [False,  True, False]])


tensor([[0.5317, 0.0000, 0.3271],
        [0.0000, 0.3433, 0.0000]])

In [43]:
import torch

# Assuming you have a tensor of shape (batch_size, num_legs)
tensor = torch.randn(5, 4)  # Example tensor with shape (5, 4)

# Define the number of joints per leg
number_of_joint_per_leg = 3

# Modify the tensor to shape (batch_size, num_legs, number_of_joint_per_leg)
modified_tensor = torch.reshape(tensor, (tensor.shape[0], tensor.shape[1], number_of_joint_per_leg))

# Check the shape of the modified tensor
print("Modified tensor shape:", modified_tensor.shape)

RuntimeError: shape '[5, 4, 3]' is invalid for input of size 20

In [134]:
import torch
import time

# Create some tensors for demonstration
T_shape = [4096,4,3]
c_shape = [4096,4]
T_1 = torch.rand(T_shape).cuda()
T_2 = torch.rand(T_shape).cuda()
c = torch.empty(c_shape, dtype=torch.bool).bernoulli(0.5).cuda()


# Create CUDA events
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# Record start event
start_event.record()

# Example operation (e.g., matrix multiplication)

result = (T_1 * c.unsqueeze(-1)) + (T_2 * (~c).unsqueeze(-1))

# Record end event
end_event.record()

# Wait for computations to finish
torch.cuda.synchronize()

# Calculate elapsed time
elapsed_time = start_event.elapsed_time(end_event) / 1000  # Convert to seconds
print("Time taken:", elapsed_time, "seconds")

Time taken: 0.00016944000124931335 seconds


In [87]:
import jax
def custom_operation(T_1, T_2, c_star):
    # Element-wise multiplication with c_star and its complement
    term1 = T_1 * c_star[..., None]
    term2 = T_2 * (~c_star)[..., None]
    
    # Sum the terms along the joint dimension
    T = term1 + term2
    
    return T

# Example usage
batch_size = 3
num_legs = 4
num_of_joints_per_leg = 5

# Random tensors for T_1, T_2, and c_star
T_1 = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_legs, num_of_joints_per_leg))
T_2 = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_legs, num_of_joints_per_leg))
c_star = jax.random.randint(jax.random.PRNGKey(2), (batch_size, num_legs), 0, 2)

# Perform custom operation
T = custom_operation(T_1, T_2, c_star)

print(T.shape)  # Output shape should be (batch_size, num_legs, num_of_joints_per_leg)

(3, 4, 5)


In [47]:
import jax
import jax.numpy as jnp
from jax import jit

@jit
def custom_operation(T_1, T_2, c_star):
    # Element-wise multiplication with c_star and its complement
    term1 = T_1 * c_star[..., None]
    term2 = T_2 * (~c_star)[..., None]
    
    # Sum the terms along the joint dimension
    T = term1 + term2
    
    return T

# Example usage
batch_size = 3
num_legs = 4
num_of_joints_per_leg = 5

# Random tensors for T_1, T_2, and c_star
T_1 = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_legs, num_of_joints_per_leg))
T_2 = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_legs, num_of_joints_per_leg))
c_star = jax.random.randint(jax.random.PRNGKey(2), (batch_size, num_legs), 0, 2)

# Move tensors to GPU
T_1 = jax.device_put(T_1, jax.devices('gpu')[0])
T_2 = jax.device_put(T_2, jax.devices('gpu')[0])
c_star = jax.device_put(c_star, jax.devices('gpu')[0])

# Perform custom operation
T = custom_operation(T_1, T_2, c_star)

print(T.shape)  # Output shape should be (batch_size, num_legs, num_of_joints_per_leg)

(3, 4, 5)


In [86]:
# Create some tensors for demonstration
batch_size = 4096
num_legs = 4
num_of_joints_per_leg = 3
T_1 = jax.random.normal(jax.random.PRNGKey(0), (batch_size, num_legs, num_of_joints_per_leg))
T_2 = jax.random.normal(jax.random.PRNGKey(1), (batch_size, num_legs, num_of_joints_per_leg))
c_star = jax.random.randint(jax.random.PRNGKey(2), (batch_size, num_legs), 0, 2)

# Create CUDA events
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# Record start event
start_event.record()

# Example operation (e.g., matrix multiplication)
T = custom_operation(T_1, T_2, c_star)

# result = (T_1 * c.unsqueeze(-1)) + (T_2 * (~c).unsqueeze(-1))

# Record end event
end_event.record()

# Wait for computations to finish
torch.cuda.synchronize()

# Calculate elapsed time
elapsed_time = start_event.elapsed_time(end_event) / 1000  # Convert to seconds
print("Time taken:", elapsed_time, "seconds")

Time taken: 0.00010966400057077407 seconds
