In [94]:
#!/usr/bin/env python
import os
import sys
import time
import torch
import dill
import numpy as np
import collections
import imageio
import cv2
from scipy.spatial import ConvexHull, Delaunay
from scipy.spatial.transform import Rotation as R

# -------------------------------
# Deoxys and Franka Interface Imports
# -------------------------------
sys.path.append("/home/franka_deoxys/deoxys_control/deoxys")
from deoxys import config_root
from deoxys.franka_interface import FrankaInterface
from deoxys.utils import YamlConfig
from deoxys.utils.config_utils import robot_config_parse_args
from deoxys.experimental.motion_utils import follow_joint_traj, reset_joints_to

sys.path.append("/home/franka_deoxys/deoxys_vision")
from deoxys_vision.networking.camera_redis_interface import CameraRedisSubInterface
from deoxys_vision.utils.camera_utils import assert_camera_ref_convention, get_camera_info

# -------------------------------
# Diffusion Policy Imports
# -------------------------------
sys.path.append("/home/franka_deoxys/diffusion_policy")
# from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace import TrainDiffusionUnetHybridWorkspace
from diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace_vision_emph import TrainDiffusionUnetHybridWorkspace


from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.common.pytorch_util import dict_apply
from util_eval import RobotStateRawObsDictGenerator, FrameStackForTrans

# -------------------------------
# Set Device
# -------------------------------
# Pick the GPU when CUDA is usable, otherwise run the policy on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [2]:
# -------------------------------
# Load Robot and Controller Configurations
# -------------------------------
# Parse the robot configuration arguments; the result supplies interface_cfg,
# controller_cfg and controller_type, which are used below.
# NOTE(review): this runs inside a notebook kernel, so confirm the parser tolerates the
# kernel's own sys.argv.
args = robot_config_parse_args()
# Franka interface built from the interface config file under deoxys' config_root
# (presumably this connects to the robot control process -- run this cell once).
robot_interface = FrankaInterface(os.path.join(config_root, args.interface_cfg))
# Controller config loaded as an EasyDict; passed to every robot_interface.control call.
controller_cfg = YamlConfig(os.path.join(config_root, args.controller_cfg)).as_easydict()
controller_type = args.controller_type

# -------------------------------
# Raw Observation Generator (robot/gripper state -> observation dict; images come from get_imgs)
# -------------------------------
raw_obs_dict_generator = RobotStateRawObsDictGenerator()

def set_gripper(open=True):
    """Send a zero-motion control command that only drives the gripper.

    Args:
        open: True sends gripper command -1.0, False sends +1.0 (the file's
            convention: -1.0 = open, 1.0 = close).

    Uses the module-level robot_interface, controller_type and controller_cfg.
    """
    gripper_cmd = -1.0 if open else 1.0
    # Six zero pose-delta entries followed by the gripper command.
    action = np.zeros(7)
    action[-1] = gripper_cmd
    robot_interface.control(
        controller_type=controller_type,
        action=action,
        controller_cfg=controller_cfg,
    )

In [95]:




# -------------------------------
# Setup Camera Interfaces
# -------------------------------
# Camera 0 feeds "agent_view" and camera 1 feeds the eye-in-hand image
# (see get_imgs / get_current_obs).
camera_ids = [0,1]
cr_interfaces = {}  # camera id -> started CameraRedisSubInterface
use_depth = False  # depth frames are not requested from the subscribers
for cam_id in camera_ids:
    camera_ref = f"rs_{cam_id}"
    # Presumably validates the "rs_<id>" reference naming before the lookup below.
    assert_camera_ref_convention(camera_ref)
    cam_info = get_camera_info(camera_ref)
    print("Camera Info:", cam_info)
    # Redis-backed frame subscriber for this camera (Redis host 127.0.0.1).
    cr_int = CameraRedisSubInterface(camera_info=cam_info, use_depth=use_depth, redis_host='127.0.0.1')
    cr_int.start()
    cr_interfaces[cam_id] = cr_int

def get_imgs(use_depth=False):
    """Fetch the latest color frame from every camera in camera_ids.

    Each frame's channel axis is reversed (BGR -> RGB). The camera-1 frame is resized
    to 320x240 (width x height, cv2 convention).

    Args:
        use_depth: accepted for API compatibility; not used by this function.

    Returns:
        dict mapping "camera_<id>_color" to the RGB image array.
    """
    def _rgb_for(cid):
        rgb = cr_interfaces[cid].get_img()["color"][..., ::-1]
        return cv2.resize(rgb, (320, 240)) if cid == 1 else rgb

    return {f"camera_{cid}_color": _rgb_for(cid) for cid in camera_ids}

def get_current_obs():
    """
    Build the latest observation dictionary (early version; a later cell redefines it).

    The newest robot and gripper states are read from robot_interface's buffers and turned
    into an observation dict by raw_obs_dict_generator. Camera images from get_imgs are
    then added:
      - 'eye_in_hand_rgb': camera-1 image as channel-first float32 scaled to [0, 1]
      - 'agent_view':      camera-0 image unchanged
    """
    state_info = {
        "last_state": robot_interface._state_buffer[-1],
        "last_gripper_state": robot_interface._gripper_state_buffer[-1],
    }
    obs = raw_obs_dict_generator.get_raw_obs_dict(state_info)
    frames = get_imgs(use_depth=False)
    eye_rgb = frames['camera_1_color']
    obs['eye_in_hand_rgb'] = np.transpose(eye_rgb, (2, 0, 1)).astype(np.float32) / 255.0
    obs['agent_view'] = frames['camera_0_color']
    return obs


Camera Info: {'camera_id': 0, 'camera_type': 'rs', 'camera_name': 'camera_rs_0'}
CameraRedisSubInterface:: {'camera_id': 0, 'camera_type': 'rs', 'camera_name': 'camera_rs_0'} True False
Camera Info: {'camera_id': 1, 'camera_type': 'rs', 'camera_name': 'camera_rs_1'}
CameraRedisSubInterface:: {'camera_id': 1, 'camera_type': 'rs', 'camera_name': 'camera_rs_1'} True False


In [207]:
# -------------------------------
# Conversion Function: Convert ee_states (T,16) -> (T,7)
# -------------------------------
def convert_rawstates_16_to_7d(raw_states):
    """
    Convert flattened 4x4 homogeneous end-effector transforms into 7-D poses.

    Each row of ``raw_states`` holds the 16 entries of a 4x4 matrix in column-major order,
    with element (r, c) at index 4*c + r. This is the only layout in which the translation
    sits at indices 12, 13 and 14, as this code assumes. Presumably it is libfranka's
    column-major O_T_EE, the likely source of "ee_states".

    In that layout the rotation block is NOT the contiguous slice 0:9, because indices 3
    and 7 are bottom-row zeros. It is therefore taken from the upper-left 3x3 of the
    un-flattened matrix.

    Args:
        raw_states: array-like of shape (T, 16). A single (16,) row is also accepted.

    Returns:
        float32 array of shape (T, 7). Each row is [x, y, z, qx, qy, qz, qw]
        (SciPy scalar-last quaternion order).
    """
    flat = np.asarray(raw_states, dtype=np.float64).reshape(-1, 16)
    # A C-order reshape of column-major data yields the transposed matrix; undo that.
    mats = flat.reshape(-1, 4, 4).transpose(0, 2, 1)
    eef_poses_7d = np.zeros((flat.shape[0], 7), dtype=np.float32)
    if flat.shape[0] == 0:
        # Nothing to convert; also avoids building an empty Rotation stack.
        return eef_poses_7d
    eef_poses_7d[:, :3] = mats[:, :3, 3]
    eef_poses_7d[:, 3:] = R.from_matrix(mats[:, :3, :3]).as_quat()
    return eef_poses_7d

# -------------------------------
# Observation Function: Get and Augment Raw Observations
# -------------------------------

# def get_imgs(use_depth=False):
#     # Dummy implementation: replace with your actual camera code.
#     dummy_img = np.zeros((240,320,3), dtype=np.uint8)
#     return {"eye_in_hand_rgb": dummy_img.transpose(2,0,1).astype(np.float32)/255.0,"agent_view": dummy_img.transpose(2,0,1)}

def get_current_obs():
    """
    Build the latest observation dictionary used by the rollout loops.

    The newest robot and gripper states are read from robot_interface's buffers and
    converted by raw_obs_dict_generator. The dict is then augmented with:
      - 'eye_in_hand_rgb_or': copy of the camera-1 image (HWC, as returned by get_imgs)
      - 'eye_in_hand_rgb':    camera-1 image as channel-first float32 scaled to [0, 1]
      - 'agent_view':         camera-0 image unchanged
      - 'ee_states_7d':       [x, y, z, qx, qy, qz, qw] rows converted from a 16-wide
                              "ee_states" entry. Only added when "ee_states" exists and its
                              last dimension is 16; a flat (16,) vector becomes shape (1, 7).
    The original "ee_states" entry is left untouched for the policy.
    """
    obs = raw_obs_dict_generator.get_raw_obs_dict({
        "last_state": robot_interface._state_buffer[-1],
        "last_gripper_state": robot_interface._gripper_state_buffer[-1],
    })
    frames = get_imgs(use_depth=False)
    eye_rgb = frames['camera_1_color']
    obs['eye_in_hand_rgb_or'] = eye_rgb.copy()
    obs['eye_in_hand_rgb'] = np.transpose(eye_rgb, (2, 0, 1)).astype(np.float32) / 255.0
    obs['agent_view'] = frames['camera_0_color']

    if "ee_states" not in obs:
        return obs
    ee = obs["ee_states"]
    if ee.ndim == 1 and ee.shape[0] == 16:
        ee = ee[None, :]  # promote a single flattened pose to a batch of one
    if ee.ndim >= 2 and ee.shape[-1] == 16:
        obs["ee_states_7d"] = convert_rawstates_16_to_7d(ee)
    return obs

# -------------------------------
# Frame Stacker
# -------------------------------
class FrameStackForTrans:
    """Fixed-length per-key observation history that returns stacked frames.

    Note: this local class replaces the FrameStackForTrans imported from util_eval.

    Each key keeps a deque of at most ``num_frames`` entries. Every entry is the
    observation value with a new leading axis, so stacking gives shape
    (num_frames, *value.shape).
    """

    def __init__(self, num_frames):
        self.num_frames = num_frames
        self.obs_history = {}

    def _stacked(self):
        # Concatenate every key's history along the new leading (time) axis.
        return {key: np.concatenate(frames, axis=0) for key, frames in self.obs_history.items()}

    def reset(self, init_obs):
        """Fill every key's history with ``num_frames`` copies of ``init_obs`` and return the stack."""
        self.obs_history = {}
        for key, value in init_obs.items():
            self.obs_history[key] = collections.deque(
                (value[None] for _ in range(self.num_frames)),
                maxlen=self.num_frames,
            )
        return self._stacked()

    def add_new_obs(self, new_obs):
        """Append ``new_obs`` (skipping keys containing 'timesteps' or 'actions') and return the stack.

        Raises KeyError for a key that was not present at reset().
        """
        for key, value in new_obs.items():
            if any(tag in key for tag in ('timesteps', 'actions')):
                continue
            self.obs_history[key].append(value[None])
        return self._stacked()

def fix_image_for_opencv(img):
    """Coerce an image into a C-contiguous, HWC, uint8 array suitable for OpenCV.

    Steps:
      - A 3-D array whose first axis has 1, 3 or 4 entries and whose other two axes are
        both > 10 is treated as channel-first and transposed to HWC.
      - A 2-D (grayscale) image is replicated into 3 channels.
      - Non-uint8 data is converted to uint8. If the maximum is <= 1.0 the data is assumed
        to be in [0, 1] and scaled by 255; otherwise it is clipped to [0, 255]. Empty
        images are cast directly, because there is no maximum to inspect.

    Returns:
        np.ndarray (C-contiguous).
    """
    img = np.asanyarray(img)
    if img.ndim == 3 and img.shape[0] in (1, 3, 4) and img.shape[1] > 10 and img.shape[2] > 10:
        img = np.transpose(img, (1, 2, 0))
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    if img.dtype != np.uint8:
        if img.size == 0:
            # img.max() raises on zero-size arrays; nothing to rescale anyway.
            img = img.astype(np.uint8)
        elif img.max() <= 1.0:
            img = (img * 255.0).clip(0, 255).astype(np.uint8)
        else:
            img = img.clip(0, 255).astype(np.uint8)
    return np.ascontiguousarray(img)

# -------------------------------
# Safe Set Loading & Utility Functions (from sim code)
# -------------------------------
def load_safe_set(file_path):
    """Load a safe-set description from an ``.npz`` archive.

    Args:
        file_path: path to an ``.npz`` file containing at least "safe_set_positions".

    Returns:
        (safe_set_positions, safe_set_orientations, hull_equations, hull_vertices):
          - safe_set_positions: the stored "safe_set_positions" array (expected (N, 3)).
          - safe_set_orientations: the stored "safe_set_orientations", or N copies of
            [1, 0, 0] when absent.
          - hull_equations / hull_vertices: the stored arrays, or None when absent.
    """
    # np.load on an .npz returns an NpzFile holding an open file handle; the context
    # manager closes it once every needed array has been materialized.
    with np.load(file_path) as data:
        safe_set_positions = data["safe_set_positions"]
        if "safe_set_orientations" in data:
            safe_set_orientations = data["safe_set_orientations"]
        else:
            n_points = safe_set_positions.shape[0]
            safe_set_orientations = np.tile(np.array([1, 0, 0]), (n_points, 1))
        hull_equations = data["hull_equations"] if "hull_equations" in data else None
        hull_vertices = data["hull_vertices"] if "hull_vertices" in data else None
    return safe_set_positions, safe_set_orientations, hull_equations, hull_vertices

def normalize_quaternion(q):
    """Return ``q`` scaled to unit norm, or the identity quaternion [0, 0, 0, 1] if near zero.

    Note: a second normalize_quaternion later in this cell overrides this definition. The
    near-zero threshold (1e-9) is kept identical to that one so both behave the same.
    """
    norm = np.linalg.norm(q)
    if norm < 1e-9:
        return np.array([0, 0, 0, 1], dtype=np.float32)
    return q / norm

def pose7d_to_6d(pose7d):
    """Map [x, y, z, qx, qy, qz, qw] to [x, y, z, rx, ry, rz] (rotation vector).

    The quaternion is normalized before conversion.
    """
    rotation_vector = R.from_quat(normalize_quaternion(pose7d[3:7])).as_rotvec()
    return np.hstack((pose7d[:3], rotation_vector))

def pose6d_to_7d(pose6d):
    """Map [x, y, z, rx, ry, rz] (rotation vector) to [x, y, z, qx, qy, qz, qw]."""
    quaternion = R.from_rotvec(pose6d[3:6]).as_quat()
    return np.hstack((pose6d[:3], quaternion))

def apply_6d_global_delta_once(pose7d, delta6d):
    """Add a 6-D delta to a 7-D pose in [position, rotation-vector] space and convert back.

    Note: summing rotation vectors only approximates composing the two rotations; it is
    exact only when both rotations share an axis.
    """
    return pose6d_to_7d(np.add(pose7d_to_6d(pose7d), delta6d))

def is_pose_in_safe_set(query_pose_7d, safe_set_positions, safe_set_orientations,
                        tol_pos=0.1, tol_ori_cos=0.9, delaunay=None, centroid=None, avg_forward=None):
    """Check a 7-D pose against the safe set in position and orientation.

    Position: inside if the point lies in a simplex of ``delaunay``. If none is given, a
    Delaunay triangulation of ``safe_set_positions`` is built here. If building or querying
    it raises, the check falls back to "distance to centroid < tol_pos" (``tol_pos`` is
    only used in that fallback).

    Orientation: inside if the cosine between the pose's x-axis ("forward") and
    ``avg_forward`` is >= ``tol_ori_cos``. When ``avg_forward`` is None it is the
    normalized mean of ``safe_set_orientations``; a caller-supplied one is used as-is.

    Returns:
        (pos_inside and ori_inside, pos_inside, ori_inside)
    """
    query_pos = query_pose_7d[:3]

    # --- position test ---
    if delaunay is None:
        try:
            pos_inside = Delaunay(safe_set_positions).find_simplex(query_pos) >= 0
        except Exception:
            # e.g. degenerate (coplanar / too few) safe-set points
            center = np.mean(safe_set_positions, axis=0) if centroid is None else centroid
            pos_inside = np.linalg.norm(query_pos - center) < tol_pos
    else:
        pos_inside = delaunay.find_simplex(query_pos) >= 0

    # --- orientation test ---
    forward = R.from_quat(normalize_quaternion(query_pose_7d[3:7])).apply([1, 0, 0])
    if avg_forward is None:
        avg_forward = np.mean(safe_set_orientations, axis=0)
        avg_forward /= np.linalg.norm(avg_forward)
    ori_inside = np.dot(forward, avg_forward) >= tol_ori_cos

    return (pos_inside and ori_inside), pos_inside, ori_inside

def compute_push_delta(current_pose_7d, safe_set_positions,
                       safety_margin=0.22, push_factor=2.0, min_push=0.1, push_direction=None):
    """Build a 6-D translational push delta (rotation part is zero).

    Magnitude: if the horizontal (x, y) distance from the current position to the
    safe-set centroid exceeds ``safety_margin``, it is
    ``(distance - safety_margin) * push_factor``; otherwise ``min_push``.

    Direction: ``push_direction`` (used as given, not normalized), or a random horizontal
    unit vector drawn from ``np.random`` when None.

    Returns:
        float64 array of shape (6,): [dx, dy, dz, 0, 0, 0].
    """
    to_centroid = np.mean(safe_set_positions, axis=0) - current_pose_7d[:3]
    to_centroid[2] = 0.0  # only the horizontal distance matters
    horizontal_dist = np.linalg.norm(to_centroid)

    magnitude = min_push
    if horizontal_dist > safety_margin:
        magnitude = (horizontal_dist - safety_margin) * push_factor

    if push_direction is None:
        heading = np.random.uniform(0, 2*np.pi)
        push_direction = np.array([np.cos(heading), np.sin(heading), 0.0])

    delta = np.zeros(6)
    delta[:3] = push_direction * magnitude
    return delta

def apply_local_pose_delta(current_pose_7d, delta_6d, max_pos=2.0, max_ori=2.0):
    """Apply a 6-D delta expressed in the pose's own (local) frame.

    The translation ``delta_6d[:3] * max_pos`` is rotated into the world frame by the
    current orientation and added to the position. The rotation
    ``delta_6d[3:6] * max_ori`` (xyz Euler angles) is right-multiplied onto the current
    orientation.

    Returns:
        array [x, y, z, qx, qy, qz, qw] of the resulting pose.
    """
    lin = [component * max_pos for component in delta_6d[:3]]
    ang = [component * max_ori for component in delta_6d[3:6]]
    start_rot = R.from_quat(normalize_quaternion(current_pose_7d[3:7]))
    end_rot = start_rot * R.from_euler('xyz', ang)
    end_pos = current_pose_7d[:3] + start_rot.apply(np.array(lin))
    return np.concatenate([end_pos, end_rot.as_quat()])

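# NOTE: compute_recovery_delta_6d uses hull_equations when load_safe_set found them in the .npz;
# when they are missing (None) it falls back to pulling toward the safe-set centroid.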

def normalize_quaternion(q):
    """Scale ``q`` to unit norm; near-zero inputs (norm < 1e-9) map to identity [0, 0, 0, 1]."""
    magnitude = np.linalg.norm(q)
    return np.array([0, 0, 0, 1], dtype=np.float32) if magnitude < 1e-9 else q / magnitude

def find_closest_forward(current_forward, safe_set_orientations):
    """
    Return the unit forward vector in ``safe_set_orientations`` (N, 3) that is most
    aligned (largest dot product) with ``current_forward``.

    If ``current_forward`` is (near) zero, the first stored row is returned unnormalized.
    Rows with near-zero norm are divided by 1e-9 instead of their norm.
    """
    length = np.linalg.norm(current_forward)
    if length < 1e-9:
        return safe_set_orientations[0]
    unit_query = current_forward / length

    row_lengths = np.maximum(np.linalg.norm(safe_set_orientations, axis=1, keepdims=True), 1e-9)
    unit_rows = safe_set_orientations / row_lengths
    return unit_rows[np.argmax(unit_rows @ unit_query)]

def compute_orientation_recovery_delta_closest(current_pose_7d, safe_set_orientations, tol_angle=0.05):
    """
    Axis-angle vector that rotates the pose's forward (x) axis onto the closest safe forward.

    The closest safe forward is chosen by find_closest_forward. Returns float32 zeros when
    the misalignment angle is below ``tol_angle``, or when the two vectors are (anti)parallel
    so that the cross product gives no usable axis. Otherwise returns unit_axis * angle and
    prints a debug line.
    """
    fwd = R.from_quat(normalize_quaternion(current_pose_7d[3:7])).apply([1, 0, 0])
    goal = find_closest_forward(fwd, safe_set_orientations)

    angle = np.arccos(np.clip(np.dot(fwd, goal), -1.0, 1.0))
    if angle < tol_angle:
        return np.zeros(3, dtype=np.float32)

    rot_axis = np.cross(fwd, goal)
    axis_len = np.linalg.norm(rot_axis)
    if axis_len < 1e-9:
        return np.zeros(3, dtype=np.float32)
    rot_axis = rot_axis / axis_len
    print(f"[Orientation Recovery Debug] angle={angle:.4f} rad, axis_norm={axis_len}")
    return rot_axis * angle

def compute_recovery_delta_6d(
    current_pose_7d,
    hull_equations,
    safe_set_positions,
    safe_set_orientations,
    pos_tol=1e-3,
    margin=0.1,
    tol_angle=0.05,
    recovery_scale=30.0,
    max_step=0.8,
):
    """
    Single-step, Nagumo-like recovery delta [dx, dy, dz, rx, ry, rz] toward the safe set.

    Position:
      - With ``hull_equations`` (rows [a, b, c, d]; a point is inside when a*x + b*y + c*z + d <= 0):
        planes with violation > -pos_tol are "active". The most violated active plane is
        selected, and the point is pushed back along that plane's inward normal by
        ``recovery_scale * violation / |n|^2``. Violations <= 0 (inside but within
        ``pos_tol`` of the boundary) produce no push, never an outward one.
      - Without hull equations: ``recovery_scale * (centroid - position)``.
      In both cases the translation step is clamped to norm ``max_step`` (None disables it).

    Orientation: compute_orientation_recovery_delta_closest with ``tol_angle``.

    Args:
        margin: kept for backward compatibility; currently unused.
        max_step: maximum norm of the translational part of the returned delta.

    Returns:
        array of shape (6,).
    """
    x = np.asarray(current_pose_7d[:3], dtype=np.float64)
    if hull_equations is None:
        centroid = np.mean(safe_set_positions, axis=0)
        delta_pos = recovery_scale * (centroid - x)
    else:
        A = hull_equations[:, :3]
        b = hull_equations[:, 3]
        violations = A @ x + b  # > 0 means outside that facet
        active = violations > -pos_tol
        if not np.any(active):
            # Strictly inside every facet by more than pos_tol.
            delta_pos = np.zeros(3)
        else:
            # Mask inactive planes with -inf. Multiplying by the boolean mask would turn them
            # into 0, letting an inactive plane win over active planes with negative violations.
            i_star = int(np.argmax(np.where(active, violations, -np.inf)))
            # Never push outward for a plane we are (slightly) inside of.
            violation = max(float(violations[i_star]), 0.0)
            Ai = A[i_star]
            delta_pos = -recovery_scale * (violation / (np.dot(Ai, Ai) + 1e-9)) * Ai

    step_norm = np.linalg.norm(delta_pos)
    if max_step is not None and step_norm > max_step:
        delta_pos = delta_pos * (max_step / step_norm)

    delta_ori = compute_orientation_recovery_delta_closest(
        current_pose_7d,
        safe_set_orientations,
        tol_angle=tol_angle
    )
    return np.concatenate([delta_pos, delta_ori])





# -------------------------------
# Undo Transform Action (as in sim code)
# -------------------------------
def undo_transform_action(action, rotation_transformer):
    """Map policy actions back to [pos(3), rotation, gripper(1)] per arm.

    The last axis is laid out as [pos(3), rot(d_rot), gripper(1)]. The rotation part is
    mapped back with ``rotation_transformer.inverse``. A 20-wide last axis is treated as two
    10-wide arms; each is converted and the result is flattened back to 14 values.
    """
    orig_shape = action.shape
    dual_arm = orig_shape[-1] == 20
    if dual_arm:
        action = action.reshape(-1, 2, 10)
    rot_dim = action.shape[-1] - 4
    pieces = [
        action[..., :3],
        rotation_transformer.inverse(action[..., 3:3 + rot_dim]),
        action[..., -1:],
    ]
    restored = np.concatenate(pieces, axis=-1)
    return restored.reshape(*orig_shape[:-1], 14) if dual_arm else restored

In [5]:
# -------------------------------
# Rollout Function for Real Robot (Using Converted ee_states_7d for Safe Set Checks)
# -------------------------------
def _policy_env_actions(policy, obs, keys_select, rotation_transformer):
    """Run one diffusion-policy inference on the stacked observations and return env actions.

    Selects the keys of ``keys_select`` present in ``obs`` and adds a leading batch axis.
    They are moved to ``device`` as tensors and passed to ``policy.predict_action`` under
    ``torch.no_grad()``. The "action" output is converted back to numpy, its rotation part
    is undone with ``undo_transform_action``, and singleton axes are squeezed.
    NOTE(review): the result is presumably (horizon, action_dim); with a horizon of 1 the
    squeeze yields a 1-D array, and iterating it would produce scalars -- confirm.
    """
    np_obs_dict = {k: obs[k][None, :] for k in keys_select if k in obs}
    obs_tensor = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
    with torch.no_grad():
        action_dict = policy.predict_action(obs_tensor)
    np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
    env_action = undo_transform_action(np_action_dict["action"], rotation_transformer)
    return env_action.squeeze()


def rollout_diffusion_real_robot(policy, rotation_transformer,
                                 safe_set_positions, safe_set_orientations,
                                 delaunay, centroid, hull_equations,
                                 n_obs_steps=2, max_steps=100, return_imgs=False,
                                 avg_forward=None, push_direction=None, push_after=0,
                                 with_recovery=False, recovery_scale=30.0, trial=0):
    """
    Roll out the diffusion policy on the real robot, with a push phase and optional recovery.

    Each outer iteration reads the newest end-effector pose from the frame-stacked
    "ee_states_7d" and checks it against the safe set (orientation cosine 0.9). Then:
      - While ``step_count < push_after``: one policy inference; only its first action is
        executed.
      - Push phase (``pushing`` starts True and stays True until the first time the pose is
        outside): a push delta from compute_push_delta / apply_local_pose_delta is sent as a
        7-element action [dpos(3), drot(3, rotation vector), gripper=0.0], followed by
        set_gripper(open=True).
      - Policy phase: one inference, then up to 4 of its actions. Before each action, the
        latest observed pose (captured before that action is sent) is re-checked. If it is
        outside and ``with_recovery`` is set, a compute_recovery_delta_6d action is sent
        instead (keeping the policy's gripper value); otherwise the policy action is sent.

    Args:
        trial: currently unused.

    Returns:
        (imgs, imgs_eye, trajectory): agent-view frames and eye-in-hand frames
        ('eye_in_hand_rgb_or'), both captured once per outer iteration when
        ``return_imgs`` is set; ``trajectory`` is a list of (position, inside_flag) tuples.
    """
    keys_select = ['joint_states', 'ee_states', 'eye_in_hand_rgb', 'gripper_states']
    framestacker = FrameStackForTrans(n_obs_steps)
    obs = get_current_obs()
    obs = framestacker.reset(obs)
    policy.reset()

    push_tol = 0.12    # passed as the safe-set tol_pos and the push safety margin
    check_tol = 0.11   # tol_pos for the per-action safety re-check
    step_count = 0
    done = False
    imgs = []
    imgs_eye = []
    trajectory = []
    pushing = True

    while step_count < max_steps and not done:
        # Stacked "ee_states_7d" is (n_obs_steps, 1, 7); take the newest pose.
        current_pose_7d = obs["ee_states_7d"][-1][0]
        inside, _, _ = is_pose_in_safe_set(
            current_pose_7d, safe_set_positions, safe_set_orientations,
            tol_pos=push_tol, tol_ori_cos=0.9,
            delaunay=delaunay, centroid=centroid, avg_forward=avg_forward
        )

        print(f"{step_count} inside {inside}")

        if step_count < push_after:
            # Initial policy phase: execute only the first predicted action.
            env_action = _policy_env_actions(policy, obs, keys_select, rotation_transformer)
            for act in env_action[:1]:
                robot_interface.control(controller_type=controller_type,
                                        action=act,
                                        controller_cfg=controller_cfg)
                step_count += 1
                new_obs = get_current_obs()
                obs = framestacker.add_new_obs(new_obs)
            trajectory.append((current_pose_7d[:3], inside))
        elif pushing and inside:
            print(f"{step_count} Pushing... inside {inside}")
            delta_6d = compute_push_delta(
                current_pose_7d, safe_set_positions,
                safety_margin=push_tol, push_factor=3.0,
                min_push=0.5, push_direction=push_direction
            )
            new_pose_7d = apply_local_pose_delta(current_pose_7d, delta_6d)
            pos_delta = new_pose_7d[:3] - current_pose_7d[:3]
            # Relative rotation as a 3-element rotation vector, matching the 7-element
            # [pos(3), rot(3), gripper] layout of the policy and recovery actions (a
            # quaternion here would make the action 8 elements long).
            rot_delta = (R.from_quat(current_pose_7d[3:]).inv()
                         * R.from_quat(new_pose_7d[3:])).as_rotvec()
            env_action = np.concatenate([pos_delta, rot_delta, np.array([0.0])])
            robot_interface.control(controller_type=controller_type,
                                    action=env_action,
                                    controller_cfg=controller_cfg)
            step_count += 1
            new_obs = get_current_obs()
            obs = framestacker.add_new_obs(new_obs)
            trajectory.append((current_pose_7d[:3], inside))
            set_gripper(open=True)
        else:
            print(f"{step_count} Policy Running ... inside {inside}")
            pushing = False  # the push phase never resumes once the pose has left the safe set
            env_action = _policy_env_actions(policy, obs, keys_select, rotation_transformer)
            for act in env_action[:4]:
                step_count += 1
                new_obs = get_current_obs()
                # new_obs["ee_states_7d"] is (1, 7) -> newest pose is a (7,) row.
                current_pose_7d = new_obs["ee_states_7d"][-1]
                inside_after, _, _ = is_pose_in_safe_set(
                    current_pose_7d, safe_set_positions, safe_set_orientations,
                    tol_pos=check_tol, tol_ori_cos=0.9,
                    delaunay=delaunay, centroid=centroid, avg_forward=avg_forward
                )

                if (not inside_after) and with_recovery:
                    print(f"{step_count} Recoverying ... inside {inside_after}")
                    recovery_delta_6d = compute_recovery_delta_6d(
                        current_pose_7d,
                        hull_equations,
                        safe_set_positions,
                        safe_set_orientations,
                        tol_angle=0.6, recovery_scale=recovery_scale
                    )
                    # Replace the policy action with the recovery delta, keeping its gripper value.
                    recovery_delta_7d = np.concatenate([recovery_delta_6d, act[-1:]], axis=0)
                    robot_interface.control(controller_type=controller_type,
                                            action=recovery_delta_7d,
                                            controller_cfg=controller_cfg)
                    time.sleep(0.1)

                    step_count += 1
                    new_obs = get_current_obs()
                    current_pose_7d = new_obs["ee_states_7d"][-1]
                else:
                    robot_interface.control(controller_type=controller_type,
                                            action=act,
                                            controller_cfg=controller_cfg)
                    time.sleep(0.1)
                obs = framestacker.add_new_obs(new_obs)
                trajectory.append((current_pose_7d[:3], inside_after))
        if step_count >= max_steps:
            done = True
        if return_imgs:
            obs_img = get_current_obs()
            imgs.append(obs_img['agent_view'])
            imgs_eye.append(obs_img['eye_in_hand_rgb_or'])
    return imgs, imgs_eye, trajectory

def undo_transform_action(action, rotation_transformer):
    """Convert policy-space actions back to [pos(3), rotation, gripper(1)].

    The rotation slice (the last-axis width minus 4 entries after position) is mapped back
    through ``rotation_transformer.inverse``. A 20-wide last axis is handled as a pair of
    10-wide arms and re-flattened to 14 values at the end.
    """
    shape_in = action.shape
    is_bimanual = (shape_in[-1] == 20)
    per_arm = action.reshape(-1, 2, 10) if is_bimanual else action
    n_rot = per_arm.shape[-1] - 4
    translation = per_arm[..., :3]
    rotation = per_arm[..., 3:3 + n_rot]
    grip = per_arm[..., -1:]
    out = np.concatenate((translation, rotation_transformer.inverse(rotation), grip), axis=-1)
    if is_bimanual:
        out = out.reshape(*shape_in[:-1], 14)
    return out

In [264]:
def is_robot_stuck(
    position_history: list,
    iteration_threshold: int = 10,
    net_displacement_threshold: float = 0.005
) -> bool:
    """
    Decide whether the end-effector has effectively stopped moving.

    Compares the position ``iteration_threshold`` entries back with the latest one and
    reports True when their Euclidean distance is below ``net_displacement_threshold``.
    With fewer than ``iteration_threshold`` samples it returns False.

    position_history: sequence (list or deque) of 3-D positions, most recent last.
    """
    if len(position_history) < iteration_threshold:
        return False  # not enough history to judge
    window_start, latest = (np.array(position_history[i]) for i in (-iteration_threshold, -1))
    return np.linalg.norm(latest - window_start) < net_displacement_threshold

from collections import deque

# -------------------------------
# Rollout Function for Real Robot (Using Converted ee_states_7d for Safe Set Checks)
# -------------------------------
def rollout_diffusion_real_robot_task_cond(policy, rotation_transformer,
                                 safe_set_positions, safe_set_orientations,
                                 delaunay, centroid, hull_equations,
                                 n_obs_steps=2, max_steps=100, return_imgs=False,
                                 avg_forward=None, pushing = False,push_direction=None, push_after=0,
                                 with_recovery=False, recovery_scale=30.0, trial=0):
    """
    Perform a rollout on the real robot:
      - Uses a frame stacker to build observations.
      - For initial steps (< push_after), uses the diffusion policy directly.
      - Then if the current pose (from "ee_states_7d") is inside the safe set,
        it executes a push phase; otherwise, it uses the diffusion policy.
      - If the resulting pose is unsafe and recovery is enabled, a recovery delta is applied.
    """
    keys_select = ['joint_states', 'ee_states', 'eye_in_hand_rgb', 'gripper_states']
    framestacker = FrameStackForTrans(n_obs_steps)
    obs = get_current_obs()
    obs = framestacker.reset(obs)
    policy.reset()

    push_tol = 0.2
    check_tol = 0.1
    recovery_tol = 0.01
    step_count = 0
    done = False
    imgs = []
    imgs_eye = []
    trajectory = []
    RECOVERY_STEPS = 5
    # pushing = False
    first_task_done = False
    full_done = False
    gripper_state = []
    vertical_push = True
    vertical_push_amount = 0.5
    position_history = deque(maxlen=10)

    while step_count < max_steps and not done:
        # Use the converted version "ee_states_7d" for safe-set checking.
        # obs = get_current_obs()
        current_pose_7d = obs["ee_states_7d"][-1][0]
        # print(current_pose_7d)
        inside, _, _ = is_pose_in_safe_set(
            current_pose_7d, safe_set_positions, safe_set_orientations,
            tol_pos=push_tol, tol_ori_cos=0.4,
            delaunay=delaunay, centroid=centroid, avg_forward=avg_forward
        )
        position_history.append(current_pose_7d[:3].copy())

        print(f"{step_count} inside {inside}")
        # stuck check
        stuck = is_robot_stuck(
            list(position_history),
            net_displacement_threshold=0.005
        )

        if step_count < push_after:
            # Initial policy phase: use diffusion policy.
            np_obs_dict = {k: obs[k][None, :] for k in keys_select if k in obs}
            obs_tensor = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
            with torch.no_grad():
                action_dict = policy.predict_action(obs_tensor)
            np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
            env_action = np_action_dict["action"]
            env_action = undo_transform_action(env_action, rotation_transformer)
            env_action = env_action.squeeze()
            for act in env_action[:4]:
                
                # print(act)
                robot_interface.control(controller_type=controller_type,
                                          action=act,
                                          controller_cfg=controller_cfg)
                step_count += 1
                new_obs = get_current_obs()
                gripper_state = new_obs["gripper_states"]
                obs = framestacker.add_new_obs(new_obs)
            trajectory.append((current_pose_7d[:3], inside))
        else:
            
            if pushing and inside:
                set_gripper(open=True)
                print(f"{step_count} Pushing... inside {inside}")
                delta_6d = compute_push_delta(
                    current_pose_7d, safe_set_positions,
                    safety_margin=push_tol, push_factor=3.0,
                    min_push=0.3, push_direction=push_direction
                )
                new_pose_7d = apply_local_pose_delta(current_pose_7d, delta_6d)
                pos_delta = new_pose_7d[:3] - current_pose_7d[:3]
                quat_current = current_pose_7d[3:]
                quat_new = new_pose_7d[3:]
                quat_current_inv = R.from_quat(quat_current).inv()
                quat_delta = (quat_current_inv * R.from_quat(quat_new)).as_quat()
                env_action = np.concatenate([pos_delta, quat_delta,np.array([0])])
                robot_interface.control(controller_type=controller_type,
                                          action=env_action,
                                          controller_cfg=controller_cfg)
                step_count += 1
                new_obs = get_current_obs()
                obs = framestacker.add_new_obs(new_obs)
                trajectory.append((current_pose_7d[:3], inside))
                # set_gripper(open=True)
            else:
                
                print(f"{step_count} Policy Running ... inside {inside}")
                pushing = False
                np_obs_dict = {k: obs[k][None, :] for k in keys_select if k in obs}
                obs_tensor = dict_apply(np_obs_dict, lambda x: torch.from_numpy(x).to(device))
                with torch.no_grad():
                    action_dict = policy.predict_action(obs_tensor)
                np_action_dict = dict_apply(action_dict, lambda x: x.detach().cpu().numpy())
                env_action = np_action_dict["action"]
                env_action = undo_transform_action(env_action, rotation_transformer)
                env_action = env_action.squeeze()
                for act in env_action:

                    if stuck:
                    # vertical lift only when stuck
                        print("Pushing Upward..")
                        delta_6d = np.zeros(6)
                        delta_6d[2] = vertical_push_amount
                        delta_7d= np.concatenate([delta_6d,act[-1:]],axis=0)
                        robot_interface.control(controller_type=controller_type,
                                              action=delta_7d,
                                              controller_cfg=controller_cfg)
                        time.sleep(0.1)

                    # print(act)
                    # robot_interface.control(controller_type=controller_type,
                    #                           action=act,
                    #                           controller_cfg=controller_cfg)
                    step_count += 1
                    new_obs = get_current_obs()
                    new_pose_stack = new_obs["ee_states_7d"]
                    current_pose_7d = new_pose_stack[-1]
                    inside_after, _, _ = is_pose_in_safe_set(
                        current_pose_7d, safe_set_positions, safe_set_orientations,
                        tol_pos=check_tol, tol_ori_cos=0.01,
                        delaunay=delaunay, centroid=centroid, avg_forward=avg_forward
                    )
                    if step_count >60 and new_obs["gripper_states"][0]<0.02:
                        first_task_done = True
                        print("First Task has been completed..")
                    else:
                        trajectory.append((current_pose_7d[:3], inside))

                    if first_task_done and new_obs["gripper_states"][0]>0.06:
                        full_done = True
                        break

                    

                    
                    if (not inside_after) and with_recovery and (not first_task_done):
                        print(f"{step_count} Recoverying ... inside {inside_after}")
                        recovery_delta_6d = compute_recovery_delta_6d(
                            current_pose_7d,
                            hull_equations,
                            safe_set_positions,
                            safe_set_orientations,
                            pos_tol = check_tol,
                            tol_angle=0.01, recovery_scale=recovery_scale
                        )
                        # print("Recovery delta :",recovery_delta_6d)
                        # small_delta_6d = recovery_delta_6d / float(RECOVERY_STEPS)
                        recovery_delta_7d= np.concatenate([recovery_delta_6d,act[-1:]],axis=0)
                        
                        robot_interface.control(controller_type=controller_type,
                                                  action=recovery_delta_7d,
                                                  controller_cfg=controller_cfg)
                        time.sleep(0.1)

                        step_count += 1
                        new_obs = get_current_obs()
                        current_pose_7d = new_obs["ee_states_7d"][-1]
                    else:
                        # print("Actual Delta : ",act)
                        robot_interface.control(controller_type=controller_type,
                                              action=act,
                                              controller_cfg=controller_cfg)
                        time.sleep(0.1)
                    
                    if return_imgs:
                        obs_img = get_current_obs()
                        imgs.append(obs_img['agent_view'])
                        imgs_eye.append(obs_img['eye_in_hand_rgb_or'])
                    
                    
                    obs = framestacker.add_new_obs(new_obs)
                    if first_task_done:
                        trajectory.append((current_pose_7d[:3], None))
                if full_done:
                    print("Full Task Completed..")
                    # -------------------------------
                    # Reset Robot and Prepare for Rollout
                    # -------------------------------
                    reset_joint_positions = [-0.048,0.07,0.007,-1.429,-0.007,1.548,0.72 ]
                    set_gripper(open=True)
                    reset_joints_to(robot_interface, reset_joint_positions)
                    set_gripper(open=True)
                    break
                    
                if not first_task_done:
                    trajectory.append((current_pose_7d[:3], inside))
        if step_count >= max_steps:
            done = True
        
    return imgs, imgs_eye, trajectory

def undo_transform_action(action, rotation_transformer):
    """Convert policy actions from the training rotation representation back for the env.

    Each action vector along the last axis is laid out as
    ``[pos(3), rot(d_rot), gripper(1)]``. ``rot`` is converted with
    ``rotation_transformer.inverse``. A last axis of width 20 is treated as two
    10-wide sub-actions (e.g. two arms). Each is converted separately and the result
    is flattened back to one row per original row.

    Args:
        action: numpy array whose last axis holds the action vector(s).
        rotation_transformer: object exposing ``inverse(rot)``, or ``None`` when the
            actions were not rotation-transformed (e.g. ``abs_action=False``). In that
            case ``action`` is returned unchanged.

    Returns:
        numpy array with the same leading shape as ``action``. Its last axis is
        ``3 + <width of inverse rotation> + 1`` (doubled for the 20-wide case).
    """
    if rotation_transformer is None:
        # Nothing was transformed at training time, so there is nothing to undo.
        return action

    raw_shape = action.shape
    dual = raw_shape[-1] == 20
    if dual:
        # Two concatenated 10-D sub-actions; split so each is converted on its own.
        action = action.reshape(-1, 2, 10)
    d_rot = action.shape[-1] - 4  # remaining width after 3 position dims + 1 gripper dim
    pos = action[..., :3]
    rot = action[..., 3:3 + d_rot]
    gripper = action[..., -1:]
    rot = rotation_transformer.inverse(rot)
    uaction = np.concatenate([pos, rot, gripper], axis=-1)
    if dual:
        # -1 rather than a hard-coded 14 so any inverse-rotation width works.
        uaction = uaction.reshape(*raw_shape[:-1], -1)
    return uaction

In [99]:
# -------------------------------
# Load Diffusion Policy Checkpoint and Setup Policy
# -------------------------------
checkpoint_path = '/home/franka_deoxys/riad/candy/multiple/after_train_500_epochs.ckpt'
# checkpoint_path = '/home/franka_deoxys/after_train_500_epochs.ckpt'

# NOTE(security): the checkpoint is unpickled with dill, which can execute arbitrary
# code. Only load checkpoints from trusted sources.
with open(checkpoint_path, 'rb') as ckpt_file:
    payload = torch.load(ckpt_file, pickle_module=dill)

# Rebuild the training workspace from the stored config, then restore its weights.
cfg = payload['cfg']
workspace = TrainDiffusionUnetHybridWorkspace(cfg, output_dir=None)
workspace.load_payload(payload, exclude_keys=None, include_keys=None)

# Use the EMA copy of the model when the training config enabled it.
policy = workspace.ema_model if getattr(cfg.training, "use_ema", False) else workspace.model
policy.to(device)
policy.eval()
print("Diffusion policy loaded and set to eval mode.")

# -------------------------------
# Initialize Rotation Transformer
# -------------------------------
# Only absolute actions go through the axis_angle <-> rotation_6d conversion;
# undo_transform_action() uses its .inverse() on policy outputs.
abs_action = True
rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d') if abs_action else None



using obs modality: low_dim with keys: ['ee_states', 'gripper_states', 'joint_states']
using obs modality: rgb with keys: ['eye_in_hand_rgb']
using obs modality: depth with keys: []
using obs modality: scan with keys: []
Diffusion params: 2.528022e+08
Vision params: 1.119709e+07
Diffusion policy loaded and set to eval mode.


In [160]:
def get_push_direction(seed):
    """Return a reproducible unit push direction in the horizontal (z = 0) plane.

    A heading angle is drawn uniformly from [0, 2*pi) with a dedicated
    ``np.random.RandomState(seed)``, so the global NumPy RNG is not touched.
    The same seed always gives the same direction.
    """
    angle = np.random.RandomState(seed).uniform(0, 2 * np.pi)
    return np.array([np.cos(angle), np.sin(angle), 0.0])

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Safe set from NPZ: positions, orientations and hull equations (4th value unused here).
safe_set_positions, safe_set_orientations, hull_equations, _ = load_safe_set(
    "/home/franka_deoxys/fr3/safe_set_6d_real_robot_additionals.npz"
)

# Geometry shared by every rollout: Delaunay triangulation and centroid of the safe-set
# positions (passed to is_pose_in_safe_set), plus the mean orientation scaled to unit length.
global_delaunay = Delaunay(safe_set_positions)
global_centroid = np.mean(safe_set_positions, axis=0)
avg_forward = np.mean(safe_set_orientations, axis=0)
avg_forward = avg_forward / np.linalg.norm(avg_forward)

# Fixed seed -> same horizontal push direction on every run.
push_direction = get_push_direction(80)

# Optional: rotation transformer for absolute actions (repeated here so this cell
# does not depend on the policy-loading cell having run).
abs_action = True
rotation_transformer = RotationTransformer('axis_angle', 'rotation_6d') if abs_action else None

Using device: cuda


In [268]:



# -------------------------------
# Reset Robot and Prepare for Rollout
# -------------------------------
# Home configuration: 7 joint values (presumably radians -- confirm against deoxys).
reset_joint_positions = [
    -0.048, 0.07, 0.007, -1.429, -0.007, 1.548, 0.72,
]

# Open the gripper, move to the home configuration, then make sure it is open again.
set_gripper(open=True)
reset_joints_to(robot_interface, reset_joint_positions)
set_gripper(open=True)

JOINT_POSITION


In [265]:


# Run one real-robot rollout with the diffusion policy.
# The robot reset lives in its own cell; run it first if the arm is not at the home pose.
SAVE_VIDEO = False  # set True to also save the agent-view frames as an mp4

policy.reset()

try:
    # Run Real Robot Rollout
    imgs, imgs_eyes, trajectory = rollout_diffusion_real_robot_task_cond(
        policy, rotation_transformer,
        safe_set_positions, safe_set_orientations,
        global_delaunay, global_centroid, hull_equations,
        n_obs_steps=2, max_steps=300, return_imgs=True,
        avg_forward=avg_forward, pushing=False, push_direction=push_direction, push_after=0,
        with_recovery=False, recovery_scale=60, trial=0
    )
finally:
    # Close the robot interface even if the rollout raises or is interrupted
    # (e.g. KeyboardInterrupt), so the connection is not left open.
    # NOTE(review): later cells still call robot_interface; confirm it is usable
    # after close() or re-create it before issuing further commands.
    robot_interface.close()

if SAVE_VIDEO and imgs:
    video_filename = "real_robot_trial_output.mp4"
    imageio.mimwrite(video_filename, imgs, fps=20, quality=8)
    print("Saved video:", video_filename)

0 inside True
0 Policy Running ... inside True
8 inside True
8 Policy Running ... inside True
16 inside True
16 Policy Running ... inside True
24 inside True
24 Policy Running ... inside True
32 inside False
32 Policy Running ... inside False
40 inside False
40 Policy Running ... inside False
48 inside False
48 Policy Running ... inside False
56 inside True
56 Policy Running ... inside True
64 inside True
64 Policy Running ... inside True
72 inside False
72 Policy Running ... inside False
80 inside True
80 Policy Running ... inside True
88 inside True
88 Policy Running ... inside True
96 inside True
96 Policy Running ... inside True
104 inside True
104 Policy Running ... inside True
112 inside True
112 Policy Running ... inside True
120 inside True
120 Policy Running ... inside True
128 inside False
128 Policy Running ... inside False
136 inside False
136 Policy Running ... inside False
144 inside False
144 Policy Running ... inside False
152 inside False
152 Policy Running ... inside 

In [266]:
import numpy as np
from scipy.spatial import ConvexHull
import plotly.graph_objects as go

def load_safe_set_positions(file_path):
    """Load the safe-set positions from an NPZ file.

    Reads the ``"safe_set_positions"`` array if present. Otherwise it falls back to
    the first three columns of the ``"safe_set"`` array.

    Args:
        file_path: path to an ``.npz`` archive.

    Returns:
        numpy array of positions (one row per safe-set sample).

    Raises:
        KeyError: if neither ``"safe_set_positions"`` nor ``"safe_set"`` is stored.
    """
    # Use the NpzFile as a context manager so its file handle is closed. Indexing an
    # NpzFile reads the array into memory, so the result stays valid after closing.
    with np.load(file_path) as data:
        if "safe_set_positions" in data:
            safe_set_positions = data["safe_set_positions"]
        else:
            safe_set_positions = data["safe_set"][:, :3]
    return safe_set_positions
def _trajectory_marker_style(trajectory):
    """Return ``(colors, sizes)`` lists with exactly one entry per trajectory point.

    Truthy ``inside`` -> green, falsy -> red, ``None`` -> a hidden (size 0) gray
    marker. None points are hidden rather than dropped so the arrays stay aligned
    with the positions.
    """
    colors, sizes = [], []
    for _, inside in trajectory:
        if inside is None:
            colors.append('gray')
            sizes.append(0)
        elif inside:
            colors.append('green')
            sizes.append(5)
        else:
            colors.append('red')
            sizes.append(5)
    return colors, sizes


def plot_trajectory_on_hull_plotly(safe_set_positions, trajectory):
    """Show the safe-set convex hull and a trajectory in an interactive 3D Plotly figure.

    Args:
        safe_set_positions: array of safe-set points with x, y, z in columns 0-2. Its
            convex hull is drawn as a translucent red mesh.
        trajectory: sequence of ``(position, inside)`` pairs. ``position`` holds x, y, z.
            ``inside`` is truthy/falsy for safe-set membership, or ``None`` for points
            recorded without a membership label.

    The connecting line runs through every point. Markers are green (inside), red
    (outside) or hidden (``None``). Calls ``fig.show()`` and returns nothing.
    """
    # Convex hull of the safe set; its triangular simplices index into the point array.
    hull = ConvexHull(safe_set_positions)
    mesh_trace = go.Mesh3d(
        x=safe_set_positions[:, 0],
        y=safe_set_positions[:, 1],
        z=safe_set_positions[:, 2],
        i=hull.simplices[:, 0],
        j=hull.simplices[:, 1],
        k=hull.simplices[:, 2],
        color='red',
        opacity=0.3,
        name='Safe Set Hull'
    )

    traj_positions = np.array([pt for pt, _ in trajectory], dtype=float)
    if traj_positions.size == 0:
        # Keep column indexing valid for an empty trajectory.
        traj_positions = np.empty((0, 3))
    # One color/size per point so markers stay aligned with positions even when
    # None-labelled entries are interleaved with labelled ones.
    traj_colors, traj_sizes = _trajectory_marker_style(trajectory)

    traj_trace = go.Scatter3d(
        x=traj_positions[:, 0],
        y=traj_positions[:, 1],
        z=traj_positions[:, 2],
        mode='lines+markers',
        line=dict(color='gray', width=2),  # line color neutral
        marker=dict(size=traj_sizes, color=traj_colors),
        name='Trajectory'
    )

    fig = go.Figure(data=[mesh_trace, traj_trace])
    fig.update_layout(
        title='Trajectory with Safe Set Membership',
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        )
    )
    fig.show()


# Reload the safe-set positions from disk and plot the latest rollout's trajectory on its hull.
# NOTE(review): this rebinds the global `safe_set_positions` used by later cells; it
# reads the same file as the setup cell, presumably yielding the same array -- confirm.
safe_set_filepath = "/home/franka_deoxys/fr3/safe_set_6d_real_robot_additionals.npz"
safe_set_positions = load_safe_set_positions(safe_set_filepath)
plot_trajectory_on_hull_plotly(safe_set_positions, trajectory)


In [267]:
# 1) 2D convex hull of the safe set projected onto the XY plane (top-down view).
#    For 2-D inputs scipy returns hull vertices in counter-clockwise order, so they
#    can be filled directly as a polygon.
pts2d = safe_set_positions[:, :2]                  # (N, 2)
hull2d = ConvexHull(pts2d)
hull_pts = pts2d[hull2d.vertices]                  # (M, 2)

# 2) Affine world -> image mapping covering the XY bounds plus 5% padding per side.
W = H = 512                                        # output image width/height in pixels
min_xy, max_xy = pts2d.min(axis=0), pts2d.max(axis=0)
margin = 0.05 * (max_xy - min_xy)
min_xy, max_xy = min_xy - margin, max_xy + margin
scale = np.array([W, H]) / (max_xy - min_xy)       # pixels per world unit along x, y

def world_to_img(pt):
    """Map a point's world (x, y) to integer pixel ``(u, v)`` in the top-down view.

    Only the first two components of ``pt`` are used, so 2D or 3D points both work.
    Relies on the module-level ``min_xy``, ``scale``, ``W`` and ``H``. Coordinates
    outside the mapped range are clamped to the image border.
    """
    uv = (pt[:2] - min_xy) * scale
    u = int(np.clip(uv[0], 0, W - 1))
    # Image rows grow downward, so y is flipped: a larger world y gives a smaller v,
    # i.e. it appears HIGHER in the image.
    v = int(np.clip(H - 1 - uv[1], 0, H - 1))
    return (u, v)


# 3) Render one frame of the safe‑set + trajectory in 2D
def render_safe_set_frame_cv2(trajectory, idx):
    """Draw the top-down safe-set hull and the trajectory prefix ``trajectory[:idx+1]``.

    Returns an ``(H, W, 3)`` uint8 image built from black, with layers drawn in this order:

    1. The hull polygon filled with color ``(0, 0, 255)``.
    2. The path as a 1-px ``(255, 255, 255)`` polyline (when it has 2+ points).
    3. A filled radius-5 circle per point: ``(0, 255, 0)`` when ``inside == True``,
       ``(255, 0, 0)`` otherwise. Points whose ``inside`` is ``None`` get no circle.
    """
    canvas = np.zeros((H, W, 3), dtype=np.uint8)
    hull_poly = np.array([world_to_img(vertex) for vertex in hull_pts], np.int32)
    cv2.fillPoly(canvas, [hull_poly], (0, 0, 255))

    visible = trajectory[:idx + 1]
    # Map each visible point to pixels once and reuse it for both the line and the markers.
    pixels = [world_to_img(position) for position, _ in visible]
    if len(pixels) > 1:
        cv2.polylines(canvas, [np.array(pixels, np.int32)], False, (255, 255, 255), 1)

    for pixel, (_, inside) in zip(pixels, visible):
        if inside is None:
            continue  # unlabeled point: part of the line, but no marker
        marker_color = (0, 255, 0) if inside == True else (255, 0, 0)
        cv2.circle(canvas, pixel, 5, marker_color, -1)
    return canvas


# 4) Side‑by‑side helper
def side_by_side(left, right):
    """Concatenate two images horizontally.

    When the heights differ, ``right`` is resized to ``left``'s height and keeps its
    aspect ratio; ``left`` is never modified.
    """
    left_h, _ = left.shape[:2]
    right_h, right_w = right.shape[:2]
    if right_h != left_h:
        # cv2.resize expects the target size as (width, height).
        right = cv2.resize(right, (int(right_w * (left_h / right_h)), left_h))
    return np.hstack((left, right))

# 5) Write out the video
def write_side_by_side_video(agent_frames, traj_frames, out_path, fps=15):
    """Write an mp4 pairing each agent frame with its trajectory frame, left to right.

    Prints both input lengths first. Only as many frames as the shorter input
    provides are written.
    """
    print(len(agent_frames), len(traj_frames))
    # zip() stops at the shorter sequence, i.e. min(len(agent_frames), len(traj_frames)).
    combined = [side_by_side(agent, traj) for agent, traj in zip(agent_frames, traj_frames)]
    imageio.mimwrite(out_path, combined, fps=fps)
    print(f"Saved side‑by‑side video to {out_path}")

# render the safe‑set side
safe_set_frames = [
    render_safe_set_frame_cv2(trajectory, frame_idx)
    for frame_idx, _ in enumerate(trajectory)
]

# Write the final video.
# NOTE(review): imgs and trajectory are appended at different points in the rollout
# loop (this run produced 304 vs 342 entries), so frame i of each need not correspond;
# the writer just truncates to the shorter list.
write_side_by_side_video(
    imgs,
    safe_set_frames,
    "baseline_fail_with_agent_view_push_10_side_by_side_full_task.mp4",
    fps=5,
)

304 342
Saved side‑by‑side video to baseline_fail_with_agent_view_push_10_side_by_side_full_task.mp4


In [None]:
import cv2
import plotly.graph_objects as go
import plotly.io as pio
from scipy.spatial import ConvexHull

# 1) Convex hull of the 3D safe set. Its triangular simplices index into
#    safe_set_positions and are reused for every rendered frame.
# Deliberately does NOT rebind `hull_pts`: render_safe_set_frame_cv2 reads the
# global `hull_pts` as the 2D top-down hull vertices. Overwriting it with the full 3D
# point set would silently corrupt that view if its cell is re-run.
hull = ConvexHull(safe_set_positions)
simplices = hull.simplices

# 2) Helper to render a Plotly 3D frame up to index i
def render_3d_frame(safe_set_positions, simplices, trajectory, idx, width, height):
    """Render the safe-set hull and the trajectory prefix ``trajectory[:idx+1]`` to RGB.

    Args:
        safe_set_positions: array of safe-set points (x, y, z in columns 0-2).
        simplices: integer array of hull triangles (three vertex indices per row).
        trajectory: sequence of ``(position, inside)`` pairs. ``inside`` is truthy/falsy
            for safe-set membership, or ``None`` for unlabeled points.
        idx: index of the last trajectory entry to include.
        width, height: figure size in pixels passed to the Plotly layout.

    Returns:
        uint8 RGB image decoded from Plotly's static PNG export. Its size follows the
        layout width/height. Requires a Plotly static-image backend (e.g. kaleido).
    """
    mesh = go.Mesh3d(
        x=safe_set_positions[:, 0],
        y=safe_set_positions[:, 1],
        z=safe_set_positions[:, 2],
        i=simplices[:, 0], j=simplices[:, 1], k=simplices[:, 2],
        color='red', opacity=0.3
    )

    visible = trajectory[:idx + 1]
    pts = np.array([p for p, _ in visible], dtype=float)
    if pts.size == 0:
        pts = np.empty((0, 3))
    # One color/size per point. Unlabeled (None) points get a hidden gray marker
    # instead of being drawn red, which would wrongly mark them as outside the safe set.
    colors = ['gray' if inside is None else ('green' if inside else 'red') for _, inside in visible]
    sizes = [0 if inside is None else 4 for _, inside in visible]
    traj = go.Scatter3d(
        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
        mode='lines+markers',
        # Gray rather than white: with showbackground=False the scene sits on the
        # paper background (white in Plotly's default template), where white vanishes.
        line=dict(color='gray', width=2),
        marker=dict(size=sizes, color=colors)
    )

    fig = go.Figure([mesh, traj])
    fig.update_layout(
        scene=dict(
            xaxis=dict(showbackground=False),
            yaxis=dict(showbackground=False),
            zaxis=dict(showbackground=False),
            aspectmode='data'
        ),
        width=width, height=height,
        margin=dict(l=0, r=0, b=0, t=0)
    )
    # Export to PNG, then decode with OpenCV (BGR) and convert to RGB.
    png = pio.to_image(fig, format='png')
    arr = np.frombuffer(png, np.uint8)
    im = cv2.imdecode(arr, cv2.IMREAD_COLOR)
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

# 3) Write out side‑by‑side video
def make_side_by_side_3d_video(agent_frames, trajectory, out_path, fps=15,
                               hull_points=None, hull_simplices=None):
    """Write an mp4: agent frames on the left, a 3D safe-set/trajectory plot on the right.

    Frame ``i`` pairs ``agent_frames[i]`` with a plot of ``trajectory[:i+1]``.

    Args:
        agent_frames: non-empty list of equally sized RGB uint8 images (H x W x 3).
        trajectory: sequence of ``(position, inside)`` pairs (see render_3d_frame).
        out_path: output video path (written with the 'mp4v' fourcc).
        fps: output frame rate.
        hull_points, hull_simplices: safe-set points and hull triangles to draw.
            They default to the module-level ``safe_set_positions`` / ``simplices``.

    Raises:
        ValueError: if ``agent_frames`` is empty.
        RuntimeError: if OpenCV cannot open a writer for ``out_path``.
    """
    if len(agent_frames) == 0:
        raise ValueError("agent_frames is empty; nothing to write")
    if hull_points is None:
        hull_points = safe_set_positions
    if hull_simplices is None:
        hull_simplices = simplices

    H, W = agent_frames[0].shape[:2]
    # OpenCV's VideoWriter takes BGR frames; output is two frames wide.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(out_path, fourcc, fps, (W * 2, H))
    if not writer.isOpened():
        raise RuntimeError(f"could not open video writer for {out_path}")
    try:
        for i, frame in enumerate(agent_frames):
            plot_img = render_3d_frame(hull_points, hull_simplices, trajectory, i, W, H)
            combo = np.hstack((frame, plot_img))
            writer.write(cv2.cvtColor(combo, cv2.COLOR_RGB2BGR))
    finally:
        # Release even on error/KeyboardInterrupt (per-frame PNG export is slow and
        # often interrupted) so the file is closed rather than left unfinalized.
        writer.release()
    print(f"Saved side‑by‑side 3D video to {out_path}")

# Build the 3D side-by-side video from the most recent rollout's agent-view
# frames (`imgs`) and its recorded `trajectory`.
make_side_by_side_3d_video(imgs, trajectory, "agent_push_2_side_by_side_3d_2.mp4", fps=5)


KeyboardInterrupt: 