In [1]:
from isaac_victor_envs.utils import get_assets_dir
from isaac_victor_envs.tasks.allegro import AllegroScrewdriverTurningEnv
# from isaac_victor_envs.tasks.allegro_ros import RosAllegroValveTurningEnv

import numpy as np
import pickle
from tqdm.notebook import tqdm

import scipy
import torch
import time
import copy
import yaml
import pathlib
from functools import partial
import sys

import pytorch_volumetric as pv
import pytorch_kinematics as pk
import pytorch_kinematics.transforms as tf
from torch.func import vmap, jacrev, hessian, jacfwd
# import pytorch3d.transforms as tf

import matplotlib.pyplot as plt
from ccai.utils.allegro_utils import *
# from allegro_valve_roll import AllegroValveTurning, AllegroContactProblem, PositionControlConstrainedSVGDMPC, \
#    add_trajectories, add_trajectories_hardware

from ccai.allegro_contact import AllegroManipulationProblem, PositionControlConstrainedSVGDMPC, add_trajectories, \
    add_trajectories_hardware
from ccai.allegro_screwdriver_problem_diffusion import AllegroScrewdriverDiff
from scipy.spatial.transform import Rotation as R

# from ccai.mpc.ipopt import IpoptMPC
# from ccai.problem import IpoptProblem
from ccai.models.trajectory_samplers import TrajectorySampler

import matplotlib
import matplotlib.pyplot as plt

from collections import defaultdict

%load_ext autoreload
%autoreload 2

%matplotlib inline

Importing module 'gym_38' (/home/abhinav/Downloads/IsaacGym_Preview_4_Package/isaacgym/python/isaacgym/_bindings/linux-x86_64/gym_38.so)
Setting GYM_USD_PLUG_INFO_PATH to /home/abhinav/Downloads/IsaacGym_Preview_4_Package/isaacgym/python/isaacgym/_bindings/linux-x86_64/usd/plugInfo.json
PyTorch version 2.4.0+cu121
Device count 2
/home/abhinav/Downloads/IsaacGym_Preview_4_Package/isaacgym/python/isaacgym/_bindings/src/gymtorch


Using /home/abhinav/.cache/torch_extensions/py38_cu121 as PyTorch extensions root...
Emitting ninja build file /home/abhinav/.cache/torch_extensions/py38_cu121/gymtorch/build.ninja...
Building extension module gymtorch...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module gymtorch...


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class AllegroScrewdriver(AllegroManipulationProblem):
    def __init__(self,
                 start,
                 goal,
                 T,
                 chain,
                 object_location,
                 object_type,
                 world_trans,
                 object_asset_pos,
                 regrasp_fingers=[],
                 contact_fingers=['index', 'middle', 'ring', 'thumb'],
                 friction_coefficient=0.95,
                 obj_dof=1,
                 obj_ori_rep='euler',
                 obj_joint_dim=0,
                 optimize_force=False,
                 turn=False,
                 obj_gravity=False,
                 min_force_dict=None,
                 device='cuda:0',
                 full_dof_goal=False, **kwargs):
        self.obj_mass = 0.1
        self.obj_dof_type = None
        if obj_dof == 3:
            object_link_name = 'screwdriver_body'
        elif obj_dof == 1:
            object_link_name = 'valve'
        elif obj_dof == 6:
            object_link_name = 'card'
        self.obj_link_name = object_link_name
        super(AllegroScrewdriver, self).__init__(start=start, goal=goal, T=T, chain=chain,
                                                 object_location=object_location,
                                                 object_type=object_type, world_trans=world_trans,
                                                 object_asset_pos=object_asset_pos,
                                                 regrasp_fingers=regrasp_fingers,
                                                 contact_fingers=contact_fingers,
                                                 friction_coefficient=friction_coefficient,
                                                 obj_dof=obj_dof,
                                                 obj_ori_rep=obj_ori_rep, obj_joint_dim=1,
                                                 optimize_force=optimize_force, device=device,
                                                 turn=turn, obj_gravity=obj_gravity,
                                                 min_force_dict=min_force_dict, 
                                                 full_dof_goal=full_dof_goal, **kwargs)
        self.friction_coefficient = friction_coefficient

    def _cost(self, xu, start, goal):
        # TODO: check if the addtional term of the smoothness cost and running goal cost is necessary
        state = xu[:, :self.dx]  # state dim = 9
        state = torch.cat((start.reshape(1, self.dx), state), dim=0)  # combine the first time step into it

        smoothness_cost = torch.sum((state[1:, -self.obj_dof:] - state[:-1, -self.obj_dof:]) ** 2)
        upright_cost = 0
        # if not self.full_dof_goal:
        upright_cost = 500 * torch.sum(
            (state[:, -self.obj_dof:-1] + goal[-self.obj_dof:-1]) ** 2)  # the screwdriver should only rotate in z direction
        return smoothness_cost + upright_cost + super()._cost(xu, start, goal)


obj_dof = 3
# config = yaml.safe_load(pathlib.Path(f'../examples/config/{sys.argv[1]}.yaml').read_text())
config = yaml.safe_load(pathlib.Path(f'../examples/config/allegro_screwdriver_csvto_only.yaml').read_text())
config['visualize'] = False

if config['mode'] == 'hardware':
    env = RosAllegroValveTurningEnv(1, control_mode='joint_impedance',
                                    use_cartesian_controller=False,
                                    viewer=True,
                                    steps_per_action=60,
                                    friction_coefficient=1.0,
                                    device=config['sim_device'],
                                    valve=config['object_type'],
                                    video_save_path=img_save_dir,
                                    joint_stiffness=config['kp'],
                                    fingers=config['fingers'],
                                    )
else:
    if not config['visualize']:
        img_save_dir = None

    env = AllegroScrewdriverTurningEnv(1, control_mode='joint_impedance',
                                        use_cartesian_controller=False,
                                        viewer=config['visualize'],
                                        steps_per_action=60,
                                        friction_coefficient=config['friction_coefficient'] * 1.05,
                                        # friction_coefficient=1.0,  # DEBUG ONLY, set the friction very high
                                        device=config['sim_device'],
                                        video_save_path=img_save_dir,
                                        joint_stiffness=config['kp'],
                                        fingers=config['fingers'],
                                        )

sim, gym, viewer = env.get_sim()

state = env.get_state()
# try:
#     while True:
#         start = env.get_state()['q'][:, :-1]
#         env.step(start)
#         print('waiting for you to finish camera adjustment, ctrl-c when done')
#         time.sleep(0.1)
# except KeyboardInterrupt:
#     pass

sim_env = None
ros_copy_node = None
if config['mode'] == 'hardware':
    sim_env = env
    from hardware.hardware_env import HardwareEnv

    env = HardwareEnv(sim_env.default_dof_pos[:, :16], finger_list=['index', 'thumb'], kp=config['kp'])
    env.world_trans = sim_env.world_trans
    env.joint_stiffness = sim_env.joint_stiffness
    env.device = sim_env.device
    env.valve_pose = sim_env.valve_pose
elif config['mode'] == 'hardware_copy':
    from hardware.hardware_env import RosNode

    ros_copy_node = RosNode()

results = {}

# set up the kinematic chain
asset = f'{get_assets_dir()}/xela_models/allegro_hand_right.urdf'
ee_names = {
    'index': 'allegro_hand_hitosashi_finger_finger_0_aftc_base_link',
    'middle': 'allegro_hand_naka_finger_finger_1_aftc_base_link',
    'ring': 'allegro_hand_kusuri_finger_finger_2_aftc_base_link',
    'thumb': 'allegro_hand_oya_finger_3_aftc_base_link',
}
config['ee_names'] = ee_names
config['obj_dof'] = 3

screwdriver_asset = f'{get_assets_dir()}/screwdriver/screwdriver.urdf'

chain = pk.build_chain_from_urdf(open(asset).read())
screwdriver_chain = pk.build_chain_from_urdf(open(screwdriver_asset).read())
frame_indices = [chain.frame_to_idx[ee_names[finger]] for finger in config['fingers']]  # combined chain
frame_indices = torch.tensor(frame_indices)
state2ee_pos = partial(state2ee_pos, fingers=config['fingers'], chain=chain, frame_indices=frame_indices,
                        world_trans=env.world_trans)

forward_kinematics = partial(chain.forward_kinematics,
                                frame_indices=frame_indices)  # full_to= _partial_state = partial(full_to_partial_state, fingers=config['fingers'])
partial_to_full_state = partial(partial_to_full_state, fingers=config['fingers'])

controller = 'csvgd'
goal = - 0.5 * torch.tensor([0, 0, np.pi])
# set up params
params = config.copy()
params.pop('controllers')
params.update(config['controllers'][controller])
params['controller'] = controller
params['valve_goal'] = goal.to(device=params['device'])
params['chain'] = chain.to(device=params['device'])
object_location = torch.tensor([0, 0, 1.205]).to(
    params['device'])  # TODO: confirm if this is the correct location
params['object_location'] = object_location

num_fingers = len(params['fingers'])
state = env.get_state()
start = state['q'].reshape(4 * num_fingers + 4).to(device=params['device'])
if 'csvgd' in params['controller']:
    # index finger is used for stability
    if 'index' in params['fingers']:
        fingers = params['fingers']
    else:
        fingers = ['index'] + params['fingers']

pregrasp_problem = AllegroScrewdriver(
    start=start[:4 * num_fingers + obj_dof],
    goal=params['valve_goal'] * 0,
    T=params['T'],
    chain=params['chain'],
    device=params['device'],
    object_asset_pos=env.table_pose,
    object_location=params['object_location'],
    object_type=params['object_type'],
    world_trans=env.world_trans,
    regrasp_fingers=fingers,
    contact_fingers=[],
    obj_dof=obj_dof,
    obj_joint_dim=1,
    optimize_force=params['optimize_force'],
)
# finger gate index
index_regrasp_problem = AllegroScrewdriver(
    start=start[:4 * num_fingers + obj_dof],
    goal=params['valve_goal'] * 0,
    T=params['T'],
    chain=params['chain'],
    device=params['device'],
    object_asset_pos=env.table_pose,
    object_location=params['object_location'],
    object_type=params['object_type'],
    world_trans=env.world_trans,
    regrasp_fingers=['index'],
    contact_fingers=['middle', 'thumb'],
    obj_dof=obj_dof,
    obj_joint_dim=1,
    optimize_force=params['optimize_force'],
    default_dof_pos=env.default_dof_pos[:, :16]
)
thumb_and_middle_regrasp_problem = AllegroScrewdriver(
    start=start[:4 * num_fingers + obj_dof],
    goal=params['valve_goal'] * 0,
    T=params['T'],
    chain=params['chain'],
    device=params['device'],
    object_asset_pos=env.table_pose,
    object_location=params['object_location'],
    object_type=params['object_type'],
    world_trans=env.world_trans,
    contact_fingers=['index'],
    regrasp_fingers=['middle', 'thumb'],
    obj_dof=obj_dof,
    obj_joint_dim=1,
    optimize_force=params['optimize_force'],
    default_dof_pos=env.default_dof_pos[:, :16]
)
turn_problem = AllegroScrewdriver(
    start=start[:4 * num_fingers + obj_dof],
    goal=params['valve_goal'] * 0,
    T=params['T'],
    chain=params['chain'],
    device=params['device'],
    object_asset_pos=env.table_pose,
    object_location=params['object_location'],
    object_type=params['object_type'],
    world_trans=env.world_trans,
    contact_fingers=['index', 'middle', 'thumb'],
    obj_dof=obj_dof,
    obj_joint_dim=1,
    optimize_force=params['optimize_force'],
    default_dof_pos=env.default_dof_pos[:, :16]
)
contact_mode_dict = {0: 'pregrasp', 2: 'index', 1: 'thumb_middle', 3: 'turn'}
t = params['T']
# with open(data_path / f'constraint_violations.p', 'wb') as f:
#     pickle.dump(constraint_violations_all, f)


Not connected to PVD
Physics Engine: PhysX
Physics Device: cpu
GPU Pipeline: disabled
Using VHACD cache directory '/home/abhinav/.isaacgym/vhacd'
Found existing convex decomposition for mesh '/home/abhinav/Documents/github/isaacgym-arm-envs/isaac_victor_envs/assets/xela_models/mesh/allegro/base_ns.stl'
Found existing convex decomposition for mesh '/home/abhinav/Documents/github/isaacgym-arm-envs/isaac_victor_envs/assets/xela_models/mesh/allegro/link_1.0.stl'
Found existing convex decomposition for mesh '/home/abhinav/Documents/github/isaacgym-arm-envs/isaac_victor_envs/assets/xela_models/mesh/ft_c.stl'


  cache = torch.load(dbpath)


In [3]:
def get_contact_points(name):
    path_base = f'/home/abhinav/Documents/ccai/data/experiments/{name}/train_data'

    all_data = []
    all_x = []
    # all_d2goal = []
    all_traj_data = []
    for collection_idx in range(4):
        path = path_base + f'/{name}_data_{collection_idx}/csvgd'
        for trial_num in range(1, 61):
            # print(path + f'/trial_{trial_num}/trajectory.npz')
            # if 'rand' in name or 'proj' in name or 'diff' in name:
            # try:
                # with open(path + f'/trial_{trial_num}/trajectory.pkl', 'rb') as data:
                #     d = pickle.load(data)
                #     traj = np.stack((d[:-1]), axis=0)
                    # end = d[-1].reshape(1, -1)
                    # end = np.concatenate((end, np.zeros((1, 21))), axis=1)
                    # traj = np.concatenate((traj, end), axis=0)
                # else:
                #     d = np.load(path + f'/trial_{trial_num}/trajectory.npz')
                #     end_state = d['x']

            data = np.load(path + f'/trial_{trial_num}/trajectory.npz')
            last_state = data['x']
            last_state = last_state.reshape(1, -1)
            # Concatenate zeros to the end of the last state
            last_last_state = np.concatenate((last_state, np.zeros((1, 21))), axis=1)
            traj = []

            with open(path + f'/trial_{trial_num}/traj_data.p', 'rb') as f:
                traj_data = pickle.load(f)
                for key in traj_data.keys():
                    if torch.is_tensor(traj_data[key]):
                        traj_data[key] = traj_data[key].cpu().numpy()
                # all_d2goal.append(d2goal)
                # if 'rand' not in name and 'proj' not in name and 'diff' not in name::
                #     traj = traj_data[t]['plans'][:, 0]

                    # end = np.concatenate((end_state, np.zeros((1, 21))), axis=1)
                    # traj = np.concatenate((traj, end), axis=0)
                for t in range(12, 1 - 1, -1):
                    traj.append(traj_data[t]['starts'][:, 0, :])

                traj = np.stack(traj, axis=1)
                last_state = traj[1:, 0]
                last_state = np.concatenate((last_state, last_last_state), axis=0)
                last_state = np.expand_dims(last_state, axis=1)
                traj = np.concatenate((traj, last_state), axis=1)
                    
                traj_data[12]['traj'] = np.expand_dims(traj, axis=1)
                all_traj_data.append(traj_data)  
                all_data.append(traj_data)

            # except:
            #     continue
            
    constraint_violations_all = {
        # 'optimizer_paths': [],
        'traj': [],
        # 'inits': [],
        # 'plans': [],

    }
    # for plans_or_inits in constraint_violations_all.keys():
    #     if plans_or_inits == 'traj':
    #         gen_constraint_data(plans_or_inits, constraint_violations_all[plans_or_inits], path, traj_data=all_traj_data)
    #     else:
    #         gen_constraint_data(plans_or_inits, constraint_violations_all[plans_or_inits], path)

    # Take the list of dicts and turn it into a dict of lists
    all_data = {k: [d[k] for d in all_data] for k in all_data[0]}
    all_data['violation'] = constraint_violations_all

    return all_x, all_data#, all_d2goal

In [4]:
data_exec = {}
t = params['T']

for key, name in [
                ('Contact Point test', 'allegro_high_force_high_eps_pi_6'),
                ]:
    
    print(key)
    data_exec[key] = {}
    all_x, all_data = get_contact_points(name)


    data_exec[key] = {**data_exec[key], **all_data}


Contact Point test


In [5]:
def _convert_robot_to_obj(coords, tf_robot_to_world, tf_world_to_obj):
    coords_world = tf_robot_to_world.transform_points(coords)
    coords_obj = tf_world_to_obj.transform_points(coords_world)
    return coords_obj

def _convert_obj_to_robot(coords, tf_robot_to_world, tf_world_to_obj):

    coords_world = tf_world_to_obj.inverse().transform_points(coords)
    coords_robot = tf_robot_to_world.inverse().transform_points(coords_world)

    return coords_robot

def convert_contact_data_to_obj(contact_points, contact_normals, tf_robot_to_world, tf_world_to_obj):
    contact_points_obj = _convert_robot_to_obj(contact_points, tf_robot_to_world, tf_world_to_obj)
    contact_normals_obj = _convert_robot_to_obj(contact_normals, tf_robot_to_world, tf_world_to_obj)
    return contact_points_obj, contact_normals_obj

def convert_contact_data_to_robot(contact_points_obj, contact_normals_obj, tf_robot_to_world, tf_world_to_obj):
    orig_shape = contact_points_obj.shape
    contact_points_obj = contact_points_obj.reshape(-1, 3)
    contact_normals_obj = contact_normals_obj.reshape(-1, 3)
    contact_points_robot = _convert_obj_to_robot(contact_points_obj, tf_robot_to_world, tf_world_to_obj)
    contact_normals_robot = _convert_obj_to_robot(contact_normals_obj, tf_robot_to_world, tf_world_to_obj)

    contact_points_robot = contact_points_robot.reshape(orig_shape)
    contact_normals_robot = contact_normals_robot.reshape(orig_shape)
    return contact_points_robot, contact_normals_robot

In [23]:
cs = turn_problem.contact_scenes
object_location = object_location.reshape(1, 3).to(device=params['device'])
tf_robot_to_world = env.world_trans.to(device=params['device'])#.inverse().to(device=params['device'])



num_trials = len(data_exec['Contact Point test'][12])
all_trajs = []
for trial_num in range(num_trials):
    all_trajs.append(torch.from_numpy(data_exec['Contact Point test'][12][trial_num]['traj'].squeeze()).to(device=params['device']))

def process_batch(all_trajs):
    all_trajs = torch.stack(all_trajs, dim=0)
    # (num_trials, num_modes, num_timesteps)
    N, C, T, _ = all_trajs.shape
    robot_q = all_trajs[..., :12]

    screwdriver_ori = all_trajs[..., -obj_dof:]
    screwdirver_ori_mat = tf.euler_angles_to_matrix(screwdriver_ori, convention='XYZ').reshape(-1, 3, 3)
    screwdriver_ori_quat = tf.matrix_to_quaternion(screwdirver_ori_mat).reshape(-1, 4)
    tf_obj_to_world = tf.Transform3d(rot=screwdriver_ori_quat, pos=object_location, device=params['device'])
    tf_world_to_obj = tf_obj_to_world.inverse()

    q_b = robot_q.reshape(-1, 4 * 3)
    theta_b = screwdriver_ori.reshape(-1, obj_dof)
    theta_obj_joint = torch.zeros((theta_b.shape[0], 1),
                                    device=theta_b.device)  # add an additional dimension for the cap of the screw driver
    # the cap does not matter for the task, but needs to be included in the state for the model
    theta_b = torch.cat((theta_b, theta_obj_joint), dim=1)
    full_q = partial_to_full_state(q_b)

    print(full_q.shape)

    ret_scene = cs.scene_collision_check(full_q.float(), theta_b.float(),
                                        compute_gradient=True,
                                        compute_hessian=False)

    contact_points_obj, contact_normals_obj = convert_contact_data_to_obj(ret_scene['closest_pt_world'], ret_scene['contact_normal'], tf_robot_to_world, tf_world_to_obj)
    sdf = ret_scene['sdf'].reshape(N, C, T, -1)
    contact_points_obj = contact_points_obj.reshape(N, C, T, -1, 3)
    contact_normals_obj = contact_normals_obj.reshape(N, C, T, -1, 3)
    tfs = tf_world_to_obj.get_matrix().reshape(N, C, T, 4, 4).unsqueeze(-3).repeat(1, 1, 1, 3, 1, 1)
    min_ind = torch.argmin(sdf.abs(), dim=-2)

    min_dist = torch.gather(sdf, 2, min_ind.unsqueeze(-1)).squeeze(-1)
    min_contact_points_obj = torch.gather(contact_points_obj, 2, min_ind.unsqueeze(-2).unsqueeze(-1).expand(-1, -1, -1, -1, 3)).squeeze(-3)
    min_contact_normals_obj = torch.gather(contact_normals_obj, 2, min_ind.unsqueeze(-2).unsqueeze(-1).expand(-1, -1, -1, -1, 3)).squeeze(-3)
    min_tfs = torch.gather(tfs, 2, min_ind.unsqueeze(-2).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4, 4)).squeeze(2)#[:, :, 0]

    return min_contact_points_obj, min_contact_normals_obj, min_tfs

# Loop over all_traj with batch_size 1024
min_contact_points_obj = []
min_contact_normals_obj = []
min_tfs = []
batch_size = 1
for i in tqdm(range(0, len(all_trajs), batch_size)):
    min_contact_points_obj_batch, min_contact_normals_obj_batch, min_tfs_batch = process_batch(all_trajs[i:i+batch_size])
    min_contact_points_obj.append(min_contact_points_obj_batch)
    min_contact_normals_obj.append(min_contact_normals_obj_batch)
    min_tfs.append(min_tfs_batch)

min_contact_points_obj = torch.cat(min_contact_points_obj, dim=0)
min_contact_normals_obj = torch.cat(min_contact_normals_obj, dim=0)
min_tfs = torch.cat(min_tfs, dim=0)


# Group the contact points and normals by contact mode
contact_data_obj_by_mode = defaultdict(list)
for trial_num in range(N):
    contact_states = data_exec['Contact Point test'][12][trial_num]['contact_state']
    for contact_state_idx in range(C):
        contact_state_idx_tuple = tuple(contact_states[contact_state_idx].tolist())
        contact_data_obj_by_mode[contact_state_idx_tuple].append(
            (min_contact_points_obj[trial_num, contact_state_idx], 
             min_contact_normals_obj[trial_num, contact_state_idx],
             min_tfs[trial_num, contact_state_idx])
        )

  0%|          | 0/240 [00:00<?, ?it/s]

torch.Size([91, 16])


OutOfMemoryError: CUDA out of memory. Tried to allocate 14.00 MiB. GPU 1 has a total capacity of 23.65 GiB of which 10.31 MiB is free. Process 627143 has 2.42 GiB memory in use. Process 654772 has 384.00 MiB memory in use. Including non-PyTorch memory, this process has 20.80 GiB memory in use. Of the allocated memory 20.32 GiB is allocated by PyTorch, and 34.26 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [15]:
mode = (1.0, 1.0, 1.0)

pts_th_m = torch.stack([i[0] for i in contact_data_obj_by_mode[mode]], dim=1)#.reshape(-1, 3)
orig_shape = pts_th_m.shape
pts_th_m = pts_th_m.reshape(-1, 3)
normals_th_m = torch.stack([i[1] for i in contact_data_obj_by_mode[mode]], dim=1).reshape(-1, 3)
tfs_th_m = torch.stack([i[2] for i in contact_data_obj_by_mode[mode]], dim=1).reshape(-1, 4, 4)#[0].unsqueeze(0).repeat(pts_th_m.shape[0], 1, 1)


# tfs_th_m = torch.tensor([
#     [1., 0, 0, 0],
#     [0, 1., 0, 0],
#     [0, 0, 1., -1.205],
#     [0, 0, 0, 1]
# ]).unsqueeze(0).repeat(normals_th_m.shape[0], 1, 1).to(device=params['device'])
tfs_th_m = tf.Transform3d(matrix=tfs_th_m, device=params['device'])

contact_points_obj = pts_th_m
contact_normals_obj = normals_th_m
contact_points_obj = contact_points_obj.reshape(-1, 3)
contact_normals_obj = contact_normals_obj.reshape(-1, 3)
contact_points_robot = _convert_obj_to_robot(contact_points_obj.unsqueeze(1), tf_robot_to_world, tfs_th_m)
contact_normals_robot = _convert_obj_to_robot(contact_normals_obj.unsqueeze(1), tf_robot_to_world, tfs_th_m)
pts_th_m_rob = contact_points_robot.reshape(orig_shape).squeeze()
normals_th_m_rob = contact_normals_robot.reshape(orig_shape).squeeze()


# pts_th_m_rob, normals_th_m_rob = convert_contact_data_to_robot(pts_th_m, normals_th_m, tf_robot_to_world, tf_world_to_obj)
colors = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
], device=params['device'])
colors = colors.unsqueeze(1).expand(-1, pts_th_m_rob.shape[1], -1)

pcd_points = torch.cat((pts_th_m_rob, colors), dim=-1).cpu().numpy().reshape(-1, 6)
normals = normals_th_m_rob.cpu().numpy().reshape(-1, 3)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pcd_points[:, :3])
pcd.colors = o3d.utility.Vector3dVector(pcd_points[:, 3:])
pcd.normals = o3d.utility.Vector3dVector(normals)

traj_for_viz = torch.cat((robot_q, theta_b.reshape(*robot_q.shape[:-1], -1)), dim=-1)[0:1, 0, 0]

visualize_trajectory(traj_for_viz, cs, 'images', config['fingers'], obj_dof+1, pcd=pcd)


NameError: name 'contact_data_obj_by_mode' is not defined