In [None]:
import torch
import dill

def print_key_tree(obj, prefix="", max_depth=3, current_depth=0):
    """
    Recursively print the key structure of a nested dict / object.

    Each dict entry is printed as ``<prefix><key>: <type name>``; values that
    expose ``shape`` additionally show it, other sized containers show their
    length. Nested dicts and objects with a ``__dict__`` are expanded one
    level deeper (prefix grows by two spaces) until ``max_depth`` is reached.

    Args:
        obj: A dict, or any object with a ``__dict__``. Anything else prints nothing.
        prefix: Text prepended to every line printed for this level.
        max_depth: Deepest nesting level that is still expanded.
        current_depth: Nesting level of ``obj``; callers normally leave it at 0.
    """
    if current_depth > max_depth:
        print(f"{prefix}...")
        return

    if not isinstance(obj, dict):
        # Plain objects are inspected through their attribute dict, at the same depth.
        if hasattr(obj, '__dict__'):
            print_key_tree(obj.__dict__, prefix, max_depth, current_depth)
        return

    can_descend = current_depth < max_depth
    child_prefix = prefix + "  "

    for key, value in obj.items():
        type_name = type(value).__name__
        is_sized = hasattr(value, '__len__') and not isinstance(value, (str, bytes))

        if not is_sized:
            print(f"{prefix}{key}: {type_name}")
            # Composite object: look inside its attributes.
            if can_descend and hasattr(value, '__dict__'):
                print_key_tree(value.__dict__, child_prefix, max_depth, current_depth + 1)
            continue

        if hasattr(value, 'shape'):  # tensor / array-like
            print(f"{prefix}{key}: {type_name} (shape: {value.shape})")
            continue

        # Any other container type.
        print(f"{prefix}{key}: {type_name} (len: {len(value)})")
        if not can_descend:
            continue

        if isinstance(value, dict):
            print_key_tree(value, child_prefix, max_depth, current_depth + 1)
        elif hasattr(value, '__getitem__'):
            # Only mapping-like containers have keys to sample; lists and tuples
            # are skipped explicitly instead of relying on a swallowed exception.
            keys_fn = getattr(value, 'keys', None)
            if callable(keys_fn):
                try:
                    if len(value) > 0:
                        print(f"{prefix}  Sample keys: {list(keys_fn())[:5]}")
                except Exception:
                    # Best-effort preview only. Unlike a bare ``except``, this
                    # lets KeyboardInterrupt / SystemExit propagate.
                    pass

# Load the checkpoint.
# NOTE(security): pickle_module=dill together with weights_only=False unpickles
# arbitrary objects and can execute code embedded in the file; only open
# checkpoints from a trusted source.
path = "/data/yixiang/workspace/SRCB-DexGraspVLA-Project/checkpoint/mixed_4_data/epoch=120.ckpt"
payload = torch.load(path, pickle_module=dill, weights_only=False)

print("=== Checkpoint Structure ===")
print_key_tree(payload)

# Inspect the model state dict in detail.
if 'state_dicts' in payload:
    print("\n=== State Dict Structure ===")
    # Guard the nested lookup: a checkpoint may carry 'state_dicts' without a
    # 'model' entry, which previously raised KeyError here.
    model_state = payload['state_dicts'].get('model')
    if model_state is None:
        print("Model state dict not found in payload.")
    else:
        print(f"Total parameters: {len(model_state)}")

        # Group entries by their top-level module (text before the first '.').
        # str.split returns the whole key when there is no '.', so no special case is needed.
        modules = {}
        for key in model_state.keys():
            modules.setdefault(key.split('.')[0], []).append(key)

        for module_name, keys in modules.items():
            print(f"{module_name}: {len(keys)} parameters")
            # Show the first 3 entries as examples.
            for key in keys[:3]:
                tensor = model_state[key]
                if hasattr(tensor, 'shape'):
                    print(f"  {key}: {tensor.shape}")
            if len(keys) > 3:
                print(f"  ... and {len(keys)-3} more")

# Show the remaining top-level metadata.
print("\n=== Other Metadata ===")
for key, value in payload.items():
    if key == 'state_dicts':
        continue
    if isinstance(value, (str, int, float, bool)) or value is None:
        print(f"{key}: {value}")
    else:
        # Truncate the repr so large objects do not flood the output.
        print(f"{key}: {type(value).__name__} ({str(value)[:100]})")


In [2]:
import os
import sys
parent_dir = os.path.dirname(os.getcwd())
sys.path.insert(0, parent_dir)
from typing import Dict, Callable
import numpy as np
import torch.nn.functional as F
from torch import nn
import zarr
import yaml
from omegaconf import OmegaConf
from datetime import datetime
import torch
import dill
import hydra
from dexgraspvla.controller.policy.dexgraspvla_controller import DexGraspVLAController
from scripts.utils.profile_utils import profile_class

def create_dummy_payload(model, action_shape, state_dim=None):
    """
    Build a checkpoint-like payload with random weights for ``model``.

    The returned state dict has one entry per key of ``model.state_dict()``:
    parameters are replaced by ``torch.randn_like`` noise, everything else
    (persistent buffers, aliases of tied parameters, extra state) is copied
    unchanged so that a strict ``load_state_dict`` sees every expected key.
    Placeholder normalizer statistics for ``action`` and ``right_state`` are
    then added (overwriting same-named entries if the model already has them).

    Args:
        model: ``torch.nn.Module`` whose state-dict layout is mirrored.
        action_shape: Action shape (indexable, last entry is used) or a plain int.
        state_dim: Length of the ``right_state`` normalizer vectors. Defaults to
            the action dimension, which is what the original code always used.

    Returns:
        ``{"state_dicts": {"model": <state dict>}}``.
    """
    param_names = {name for name, _ in model.named_parameters()}

    model_state_dict = {}
    for name, value in model.state_dict().items():
        if name in param_names:
            model_state_dict[name] = torch.randn_like(value)
        elif isinstance(value, torch.Tensor):
            # Buffers may be integer-typed (randn_like would fail) and may need
            # to stay valid (e.g. running variances), so keep their values.
            model_state_dict[name] = value.detach().clone()
        else:
            model_state_dict[name] = value

    if hasattr(action_shape, '__getitem__'):
        action_dim = int(action_shape[-1])
    else:
        action_dim = int(action_shape)
    if state_dim is None:
        state_dim = action_dim
    state_dim = int(state_dim)

    def _placeholder_stats(dim):
        # Fresh tensors per call so no two state-dict entries share storage.
        return {
            'offset': torch.full((dim,), -1.0),
            'scale': torch.full((dim,), 2.0),
            'input_stats.max': torch.full((dim,), 1.0),
            'input_stats.mean': torch.zeros((dim,)),
            'input_stats.min': torch.full((dim,), -1.0),
            'input_stats.std': torch.ones((dim,)),
        }

    # NOTE: these keys must match the names in model.state_dict() exactly.
    for field, dim in (('action', action_dim), ('right_state', state_dim)):
        for suffix, tensor in _placeholder_stats(dim).items():
            model_state_dict[f'normalizer.params_dict.{field}.{suffix}'] = tensor

    payload = {
        "state_dicts": {
            "model": model_state_dict
        }
    }

    return payload

def load_config(main_config_path, task_config_path):
    """
    Build one fully resolved OmegaConf config from a main and a task YAML file.

    Args:
        main_config_path: Path of the main configuration YAML file.
        task_config_path: Path of the task configuration YAML file; its
            content is stored under the ``task`` key of the result.

    Returns:
        The merged configuration with all interpolations resolved in place.
    """
    # Custom resolvers used by the config files:
    #   ${now:<strftime pattern>} -> current time formatted with the pattern
    #   ${eval:<expr>}            -> Python's built-in eval
    # NOTE(review): the "eval" resolver runs arbitrary expressions taken from
    # the YAML files, so only load configuration files you trust.
    resolvers = {
        "now": lambda pattern: datetime.now().strftime(pattern),
        "eval": eval,
    }
    for resolver_name, resolver_fn in resolvers.items():
        OmegaConf.register_new_resolver(resolver_name, resolver_fn, replace=True)

    # Defaults for the hydra.job keys; values from the main file are merged over them.
    hydra_defaults = OmegaConf.create(
        {"hydra": {"job": {"num": 0, "override_dirname": "${name}"}}}
    )

    merged = OmegaConf.merge(hydra_defaults, OmegaConf.load(main_config_path))
    merged["task"] = OmegaConf.load(task_config_path)

    # Expand every ${...} reference now that all pieces are in place.
    OmegaConf.resolve(merged)
    return merged
def load_zarr_data(zarr_path):
    """
    Load the demonstration arrays from a Zarr store into memory.

    Args:
        zarr_path: Path of the Zarr store. It must already exist.

    Returns:
        Tuple ``(rgbm_data, action, right_cam_img, right_state, episode_ends)``
        read from ``data/rgbm``, ``data/action``, ``data/right_cam_img``,
        ``data/right_state`` and ``meta/episode_ends``. ``episode_ends`` has a
        0 inserted at the front.

    Raises:
        RuntimeError: If the store cannot be opened or an array is missing;
            the underlying exception is attached as ``__cause__``.
    """
    try:
        # Open read-only: zarr.open defaults to mode='a', which silently
        # creates an empty store when the path is mistyped.
        root = zarr.open(zarr_path, mode='r')
        print(list(root['data'].keys()))
        # [:] materializes each array fully in memory.
        rgbm_data = root['data/rgbm'][:]
        action = root['data/action'][:]
        right_cam_img = root['data/right_cam_img'][:]
        right_state = root['data/right_state'][:]
        # Prepend 0 so consecutive entries delimit each episode's index range.
        episode_ends = np.insert(root['meta/episode_ends'][:], 0, 0)
        return rgbm_data, action, right_cam_img, right_state, episode_ends
    except Exception as e:
        raise RuntimeError(f"Error loading Zarr data: {str(e)}") from e

def update_array(existing_array, new_array):
    """
    Return a copy of ``existing_array`` shifted one step along axis 0.

    The oldest entry (index 0) is dropped, every other entry moves one slot
    towards the front, and ``new_array`` is written into the last slot. The
    input array is not modified.
    """
    shifted = np.empty_like(existing_array)
    length = existing_array.shape[0]

    # Slide entries 1..end into slots 0..end-1.
    shifted[:-1, ...] = existing_array[1:, ...]

    # Newest observation goes into the final slot (nothing to fill when empty).
    if length > 0:
        shifted[length - 1, ...] = new_array

    return shifted

def dict_apply(
        x: Dict[str, torch.Tensor], 
        func: Callable[[torch.Tensor], torch.Tensor]
        ) -> Dict[str, torch.Tensor]:
    """
    Apply ``func`` to every leaf value of a (possibly nested) dict.

    Values that are themselves dicts are processed recursively; the result is
    a new dict with the same keys and nesting as ``x``.
    """
    return {
        name: dict_apply(item, func) if isinstance(item, dict) else func(item)
        for name, item in x.items()
    }

def load_inference_config(config_path):
    """Read the YAML file at ``config_path`` and return its parsed content."""
    with open(config_path, 'r') as stream:
        parsed = yaml.safe_load(stream)
    return parsed

class VLAController:
    """
    Inference wrapper around a controller policy.

    Keeps rolling buffers of the last ``n_obs_steps`` wrist-camera images,
    third-view RGB+mask images and state vectors, converts them into the
    tensors described by ``shape_meta`` and queries the model for an action.
    """

    # Length of the state vector stored in ``state_buffer``.
    STATE_DIM = 13

    def __init__(self, config, vla_cfg, model: "DexGraspVLAController", payload=None) -> None:
        """
        Args:
            config: Inference config; ``config['cameras']['right_hand_cameras']['resolution']``
                is read as a 2-item sequence (presumably ``(width, height)`` — the
                buffers are allocated as ``[T, resolution[1], resolution[0], C]``).
            vla_cfg: Policy config providing ``n_obs_steps`` and ``task.shape_meta``.
            model: Policy module; must offer ``load_state_dict``, ``to``, ``eval``
                and ``predict_action``.
            payload: Optional checkpoint dict; weights are read from
                ``payload['state_dicts']['model']`` when present.
        """
        self.config = config
        self.cfg = vla_cfg
        resolution = self.config['cameras']['right_hand_cameras']['resolution']
        n_obs_steps = self.cfg.n_obs_steps
        self.right_first_color_image_buffer = np.zeros((n_obs_steps, resolution[1], resolution[0], 3))
        self.third_color_image_buffer = np.zeros((n_obs_steps, resolution[1], resolution[0], 4))
        self.state_buffer = np.zeros((n_obs_steps, self.STATE_DIM))
        self.time_step = 0
        # Fall back to CPU so the controller is still usable without a GPU.
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.model = model

        if payload:
            model_state = payload.get('state_dicts', {}).get('model')
            if model_state:
                self.model.load_state_dict(model_state)
            else:
                print("Model state dict not found in payload.")
        else:
            print("Payload is None or empty.")

        self.model.to(self.device)
        self.model.eval()

    def predict_action(self, state, right_first_color_image, third_color_image_with_mask):
        """
        Push one new observation into the buffers and run the policy.

        Returns:
            ``model.predict_action(...)[0]`` moved to CPU as a NumPy array.
        """
        obs = self.get_obs(state, right_first_color_image, third_color_image_with_mask)
        attn_map_output_path = None  # attention-map dumping is disabled
        self.time_step += 1
        obs_dict_np = self.process_obs(env_obs=obs, shape_meta=self.cfg.task.shape_meta)
        # Add a leading batch axis and move every array to the model device.
        obs_dict = dict_apply(obs_dict_np, 
                lambda x: torch.from_numpy(x).unsqueeze(0).to(self.device))
        with torch.no_grad():
            action_pred = self.model.predict_action(obs_dict, attn_map_output_path)
            action = action_pred[0].detach().to('cpu').numpy()
        return action

    def get_obs(self, state, right_first_color_image, third_color_image_with_mask):
        """Shift the three rolling buffers by one step and return them as an observation dict."""
        self.right_first_color_image_buffer = update_array(
            self.right_first_color_image_buffer, 
            right_first_color_image
        )
        self.third_color_image_buffer = update_array(
            self.third_color_image_buffer, 
            third_color_image_with_mask
        )
        self.state_buffer = update_array(self.state_buffer, state)
        print(f"state_value: {state}")
        obs = {"right_cam_img": self.right_first_color_image_buffer, "rgbm": self.third_color_image_buffer, "right_state": self.state_buffer}
        return obs

    @staticmethod
    def _resize_rgb(imgs_in, shape):
        """Scale the first three channels of ``[T, H, W, C]`` images to [0, 1] and resize to ``[T, 3, shape[1], shape[2]]``."""
        rgb = torch.from_numpy(imgs_in[..., :3]).float()  # [T, H, W, 3]
        rgb = rgb.permute(0, 3, 1, 2)  # [T, 3, H, W]
        return F.interpolate(
            rgb / 255.0,
            size=(shape[1], shape[2]),
            mode='bilinear',
            align_corners=False
        )

    def process_obs(self, env_obs, shape_meta):
        """
        Convert raw observation buffers into model-ready NumPy arrays.

        For every key in ``shape_meta['obs']`` the entry's ``type`` (default
        ``'low_dim'``) selects the processing:

        * ``'rgb'``: channels-first, divided by 255, bilinear resize.
        * ``'rgbm'``: RGB as above; the remaining channel(s) are resized with
          nearest-neighbour and binarized at 0.5, then concatenated to RGB.
        * ``'low_dim'``: cast to ``float32``.

        Keys with any other type are left out of the result.
        """
        obs_dict_np = {}
        obs_shape_meta = shape_meta['obs']

        for key, attr in obs_shape_meta.items():
            obs_type = attr.get('type', 'low_dim')
            shape = attr.get('shape')

            if obs_type == 'rgb':
                obs_dict_np[key] = self._resize_rgb(env_obs[key], shape).numpy()

            elif obs_type == 'rgbm':  # RGB image with a mask channel
                imgs_in = env_obs[key]
                rgb = self._resize_rgb(imgs_in, shape)
                mask = torch.from_numpy(imgs_in[..., 3:]).float()
                mask = mask.permute(0, 3, 1, 2)  # [T, 1, H, W]
                # Nearest-neighbour keeps the mask free of interpolated values.
                mask = F.interpolate(
                    mask,
                    size=(shape[1], shape[2]),
                    mode='nearest'
                )
                mask = (mask > 0.5).float()
                out_imgs = torch.cat([rgb, mask], dim=1)  # [T, 4, H, W]
                obs_dict_np[key] = out_imgs.numpy()

            elif obs_type == 'low_dim':
                obs_dict_np[key] = env_obs[key].astype(np.float32)

        return obs_dict_np


In [4]:
# Load the trained checkpoint from disk.
# NOTE(security): pickle_module=dill with weights_only=False can execute code
# embedded in the file; only open checkpoints from a trusted source.
# NOTE(review): in the cells visible in this notebook the controller is built
# from `dummy_payload`, not from this `payload` — confirm whether this load is
# still needed.
path = "/data/yixiang/workspace/SRCB-DexGraspVLA-Project/checkpoint/mixed_4_data/epoch=120.ckpt"
payload = torch.load(path, pickle_module=dill, weights_only=False)


In [3]:

# Read every dataset array into memory; load_zarr_data also prints the keys
# found under the store's 'data' group.
data_path = '/data/dingzher/DexGrasp_Demo/SRCB-DexVLA/zarr_data_transfer/output_data_20250610_single_bowl.zarr/'
rgbm_data, action, right_cam_img, right_state, episode_ends = load_zarr_data(data_path)


['action', 'rgbm', 'right_cam_img', 'right_state']


In [13]:
# Sanity-check the shapes of the loaded arrays.
# NOTE(review): the saved output shows (64, 13) for `action` while the other
# arrays have 8908 rows. This cell's execution count (13) comes after the
# predict_action cell (12), which rebinds `action`, so the printed shape is
# presumably the prediction rather than the dataset array — re-run in order
# to confirm.
print(rgbm_data.shape)
print(action.shape)
print(right_cam_img.shape)
print(right_state.shape)

(8908, 480, 640, 4)
(64, 13)
(8908, 480, 640, 3)
(8908, 13)


In [4]:

# Locate the controller config files relative to the current working directory.
current_dir = os.getcwd()
controller_config_dir = os.path.join(current_dir, '../src/dexgraspvla/controller', 'config')
main_config_path = os.path.join(controller_config_dir, 'train_dexgraspvla_controller_workspace.yaml')
task_config_path = os.path.join(controller_config_dir, 'task', 'grasp.yaml')

# Policy / training configuration (main file merged with the task file).
vla_cfg = load_config(main_config_path=main_config_path, task_config_path=task_config_path)
print(vla_cfg)

# Inference-side configuration; VLAController reads the camera resolution from it.
inf_cfg = load_inference_config('/data/shiqi/SRCB-DexGraspVLA-Project/config.yaml')
print(inf_cfg)


{'hydra': {'job': {'num': 0, 'override_dirname': 'train_dexgraspvla_controller'}, 'run': {'dir': 'data/outputs/2025.08.05/15.35_train_dexgraspvla_controller_grasp'}, 'sweep': {'dir': 'data/outputs/2025.08.05/15.35_train_dexgraspvla_controller_grasp', 'subdir': 0}}, 'defaults': ['_self_', {'task': 'grasp'}], 'name': 'train_dexgraspvla_controller', '_target_': 'dexgraspvla.controller.workspace.train_dexgraspvla_controller_workspace.TrainDexGraspVLAControllerWorkspace', 'task_name': 'grasp', 'shape_meta': {'obs': {'right_cam_img': {'shape': [3, 518, 518], 'type': 'rgb', 'horizon': 1}, 'rgbm': {'shape': [4, 518, 518], 'type': 'rgbm', 'horizon': 1}, 'right_state': {'shape': [13], 'type': 'low_dim', 'horizon': 1}}, 'action': {'shape': [13], 'horizon': 64}}, 'exp_name': 'default', 'n_action_steps': 64, 'n_obs_steps': 1, 'n_latency_steps': 0, 'dataset_obs_steps': 1, 'past_action_visible': False, 'keypoint_visible_rate': 1.0, 'obs_as_cond': True, 'policy': {'_target_': 'dexgraspvla.controller.p

In [5]:
# Build the policy object described by the `policy` section of the config
# (Hydra instantiates the class named in its `_target_` entry).
model = hydra.utils.instantiate(vla_cfg.policy)

Using cache found in /home/samsung/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/samsung/.cache/torch/hub/facebookresearch_dinov2_main


In [6]:
# Random-weight stand-in for a checkpoint, sized from the configured action shape.
dummy_payload = create_dummy_payload(model, vla_cfg['shape_meta']["action"]["shape"])

In [7]:
# Wrap the model in the inference controller, loading the dummy weights.
vla_controller=  VLAController(inf_cfg,vla_cfg, model, dummy_payload)

In [None]:
# Display the shape of the RGB+mask array (this cell has no recorded execution count).
rgbm_data[...].shape

In [12]:
# Run one prediction on the first dataset frame; the [0:1] slices keep the leading axis.
# NOTE(review): this rebinds `action`, which the dataset-loading cell assigned
# from load_zarr_data — the dataset actions are no longer reachable under that
# name afterwards. Consider a distinct name for the prediction.
action = vla_controller.predict_action(right_state[0:1,...],right_cam_img[0:1,...],rgbm_data[0:1,...])

state_value: [[-0.41853124 -0.33612895  0.56479067 -0.5503908   0.34763214  0.19313775
  -0.08707141  0.512       0.61        0.638       0.718       0.994
  -0.526     ]]

=== Profile results for predict_action ===
         56467 function calls (46336 primitive calls) in 0.343 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.343    0.343 /data/shiqi/DexGraspVLA/src/dexgraspvla/controller/policy/dexgraspvla_controller.py:162(predict_action)
  5037/16    0.003    0.000    0.336    0.021 /data/shiqi/DexGraspVLA/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1732(_wrapped_call_impl)
  5037/16    0.006    0.000    0.336    0.021 /data/shiqi/DexGraspVLA/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1740(_call_impl)
        1    0.007    0.007    0.303    0.303 /data/shiqi/DexGraspVLA/src/dexgraspvla/controller/policy/dexgraspvla_controller.py:112(solve_ode)
       

In [None]:
# Shape of whatever `action` currently refers to (dataset array or latest
# prediction, depending on which cells have run).
action.shape