In [8]:
import os

import numpy as np
import torch
from omegaconf import OmegaConf
from torch import nn

from actpp.actpp_policy import ACTPolicy
from tidybot2.policy_wrapper import PolicyWrapper
from tidybot2.utils import get_policy, rmat_to_quat, rot6d_to_rmat, get_cfg

In [2]:
# Keyword arguments unpacked into PolicyWrapper (via `**pw_cfg` in test_policy)
# when exercising the diffusion policy.
diffusion_pw_cfg = dict(n_obs=2, n_acts=8, d_pos=6, d_rot=6)

# Keyword arguments unpacked into PolicyWrapper (via `**pw_cfg` in test_policy)
# when exercising the ACT++ policy; differs from the diffusion config only in n_obs.
actpp_pw_cfg = dict(n_obs=1, n_acts=8, d_pos=6, d_rot=6)

def test_policy(policy, dummy_normalizer, pw_cfg):
    """Smoke-test a policy with a single inference step on an all-zero observation.

    The policy is moved to CUDA when available (CPU otherwise), put in eval
    mode and has ``requires_grad`` switched off on every parameter. It is then
    wrapped in a ``PolicyWrapper``, queried once with a dummy observation, and
    the raw action is printed and split into named components.

    Args:
        policy: Object exposing ``to``, ``eval`` and ``parameters`` (and
            ``set_normalizer`` when ``dummy_normalizer`` is truthy).
        dummy_normalizer: When truthy, forwarded as the second argument of
            ``policy.set_normalizer(None, dummy_normalizer)``; when falsy the
            normalizer is left untouched.
        pw_cfg: Mapping of keyword arguments forwarded to ``PolicyWrapper``.

    Returns:
        dict with keys ``'base_pose'`` (``action[:3]``), ``'arm_pos'``
        (``action[3:6]``), ``'arm_quat'`` (``action[6:12]`` passed through
        ``rot6d_to_rmat`` then ``rmat_to_quat``) and ``'gripper_pos'``
        (``action[12:13]``).
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Inference only: eval mode, and no parameter tracks gradients.
    policy = policy.to(device)
    policy.eval()
    for weight in policy.parameters():
        weight.requires_grad = False

    if dummy_normalizer:
        policy.set_normalizer(None, dummy_normalizer)

    wrapper = PolicyWrapper(policy, device=device, **pw_cfg)

    # Dummy observation for the mobile base; replace these sizes with the
    # shapes your own policy expects.
    low_dim_sizes = (
        ('base_pose', 3),
        ('arm_pos', 3),
        ('arm_rot', 6),
        ('arm_rot_wrt_start', 6),
        ('gripper_pos', 1),
    )
    obs = {key: np.zeros(size, dtype=np.float32) for key, size in low_dim_sizes}
    for camera in ('base_image', 'wrist_image'):
        obs[camera] = np.zeros((84, 84, 3), dtype=np.uint8)

    action = wrapper.get_action(obs)
    print(action)

    # Split the flat action vector; the 6-D rotation slice is converted to a
    # quaternion through an intermediate rotation matrix.
    return {
        'base_pose': action[:3],
        'arm_pos': action[3:6],
        'arm_quat': rmat_to_quat(rot6d_to_rmat(torch.from_numpy(action[6:12]))),
        'gripper_pos': action[12:13],
    }

# Diffusion Policy Example

In [7]:
# Checkpoint used by the diffusion-policy example. The default is the
# machine-specific location this notebook was originally run against; set the
# TIDYBOT2_CKPT_PATH environment variable to point at a checkpoint on your own
# machine instead of editing this cell.
ckpt_path = os.environ.get(
    'TIDYBOT2_CKPT_PATH',
    '/juno/u/aadityap/universal_manipulation_interface/data/test/checkpoints/epoch_8450.ckpt',
)

In [10]:
# Fetch the config for the checkpoint via get_cfg, then write it out as
# 'config.yaml' (relative to the kernel's working directory) for inspection.
d_cfg = get_cfg(ckpt_path)
OmegaConf.save(config=d_cfg, f='config.yaml')

In [10]:
# Build the diffusion policy from the checkpoint path using tidybot2's get_policy
# (presumably this also restores the trained weights — confirm in tidybot2.utils).
diffusion_policy = get_policy(ckpt_path)

You are using the CTM base workspace! Ensure that you don't wish to use the normal DP base workspace.


using obs modality: low_dim with keys: ['base_pose', 'arm_rot_wrt_start', 'gripper_pos', 'arm_pos', 'arm_rot']
using obs modality: rgb with keys: ['wrist_image', 'base_image']
using obs modality: depth with keys: []
using obs modality: scan with keys: []




Diffusion params: 6.721550e+07
Vision params: 2.239418e+07
_output_dir
global_step
epoch


In [12]:
# Smoke-test the diffusion policy. dummy_normalizer=False means test_policy
# skips the set_normalizer call and leaves the policy's normalizer untouched.
test_policy(
    policy=diffusion_policy,
    dummy_normalizer=False,
    pw_cfg=diffusion_pw_cfg,
)

[-0.00119794  0.00627107  0.04214258 -0.00265307  0.00286261 -0.00144666
  0.96996975  0.00254987 -0.02307275 -0.03843752  0.99904     0.00587935
  0.8120817 ]


{'base_pose': array([-0.00119794,  0.00627107,  0.04214258], dtype=float32),
 'arm_pos': array([-0.00265307,  0.00286261, -0.00144666], dtype=float32),
 'arm_quat': array([-0.00249938, -0.01189423, -0.00128439,  0.99992531]),
 'gripper_pos': array([0.8120817], dtype=float32)}

# ACT++ Example

In [3]:
# Load the example ACT++ configuration from YAML and print its `policy` section.
# The path is relative, so it is resolved against the kernel's working directory.
cfg_path = 'actpp/example_actpp_cfg.yaml'
cfg = OmegaConf.load(cfg_path)
print(cfg.policy)

{'action_dim': 13, 'backbone': 'resnet18', 'frozen_backbone': False, 'dilation': False, 'position_embedding': 'sine', 'camera_names': ['base_image', 'wrist_image'], 'no_encoder': False, 'enc_layers': 4, 'dec_layers': 7, 'dim_feedforward': 2048, 'hidden_dim': 256, 'dropout': 0.1, 'nheads': 8, 'num_queries': 400, 'pre_norm': False, 'vq': False, 'vq_class': 12, 'vq_dim': 64, 'masks': False}


In [4]:
# Instantiate the ACT++ policy from the loaded config. No checkpoint is passed
# in this cell, so the weights are presumably freshly initialised — TODO confirm
# against ACTPolicy.
actpp_policy = ACTPolicy(cfg)

Assuming state dim is 19 because this is the mobile base repo!




Use VQ: False, 12, 64
Using Camera Names ['base_image', 'wrist_image']
number of parameters: 55.23M
KL Weight 0.1


In [6]:
# Smoke-test the ACT++ policy. dummy_normalizer=True makes test_policy call
# policy.set_normalizer(None, True) before wrapping. The result is bound to `_`
# so the returned dict is not echoed as cell output.
_ = test_policy(
    policy=actpp_policy,
    dummy_normalizer=True,
    pw_cfg=actpp_pw_cfg,
)

[-1.291825   -0.72286344  0.05941978  1.4747369   0.5093037  -0.57023007
 -0.56569695  0.510679   -0.0350954  -0.7592864  -0.5361282   0.08769123
  0.37533844]


  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
