In [None]:
pip install transformers torch pillow numpy mujoco




In [None]:
!git clone https://github.com/DorianAtSchool/Franka_table.git

fatal: destination path 'Franka_table' already exists and is not an empty directory.


In [None]:
import importlib.util

import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

from Franka_table.environments.franka_4robots_env import FrankaTable4RobotsEnv

In [None]:
# ============================================================================
# 1. Setup Environment
# ============================================================================
# Creates the module-level `env` that the episode runners below use.
# NOTE(review): in the saved run this cell raised
#   ValueError: ParseXML: Error opening file
#   'franka_table/scenes/franka_emika_panda/scene_4pandas_table.xml'
# That path is not the `mjcf_path` passed below. Its lowercase "franka_table" prefix
# also does not match the cloned "Franka_table" directory, and case matters on Linux.
# Check how FrankaTable4RobotsEnv resolves its scene file and which working directory
# it expects before relying on `mjcf_path` - TODO confirm.
# NOTE(review): "../scenes/..." looks written for a notebook located inside the repo.
# From the directory that contains the clone it presumably needs a "Franka_table/..."
# prefix instead - verify the file exists.
print("Setting up Franka 4-Robots environment...")
env = FrankaTable4RobotsEnv(
    mjcf_path="../scenes/scene_4robots.xml",
    render_mode="rgb_array"  # Use rgb_array for image capture
)

Setting up Franka 4-Robots environment...


ValueError: ParseXML: Error opening file 'franka_table/scenes/franka_emika_panda/scene_4pandas_table.xml': No such file or directory

In [None]:
# ============================================================================
# 2. Load OpenVLA Model (GPU Optimized)
# ============================================================================
# Loads the OpenVLA processor and model onto the GPU and defines the
# module-level `device`, `processor` and `vla` used by the episode runners.
print("Loading OpenVLA model with GPU optimizations...")

MODEL_ID = "openvla/openvla-7b"

# Check GPU availability
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available! Please check your GPU setup.")

device = "cuda"
print(f"✓ Using device: {device}")
print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
print(f"✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print(f"✓ Initial GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Flash Attention 2 needs the optional `flash-attn` package, which the install cell does
# not provide. Requesting it without the package makes from_pretrained raise ImportError,
# so use it only when importable and otherwise fall back to plain ("eager") attention.
attn_implementation = (
    "flash_attention_2" if importlib.util.find_spec("flash_attn") is not None else "eager"
)
print(f"✓ Attention implementation: {attn_implementation}")

vla = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Automatic GPU placement (requires `accelerate`)
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation=attn_implementation,
)

# Verify model is on GPU
print("\n✓ Model loaded successfully!")
print(f"✓ Model device map: {getattr(vla, 'hf_device_map', 'N/A')}")
print(f"✓ Model dtype: {vla.dtype}")
print(f"✓ GPU Memory After Loading: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
print(f"✓ GPU Memory Reserved: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")

sample_param = next(vla.parameters())
print(f"✓ Sample parameter device: {sample_param.device}")
print(f"✓ Sample parameter dtype: {sample_param.dtype}")

# device_map="auto" can offload layers to CPU/disk when GPU memory is short, and
# checking only the first parameter would miss that, so every parameter is checked.
offloaded = [name for name, param in vla.named_parameters() if not param.is_cuda]
if offloaded:
    raise RuntimeError(
        f"Model parameters are NOT on GPU! {len(offloaded)} tensor(s) elsewhere, "
        f"e.g. {offloaded[0]}"
    )

In [None]:
# ============================================================================
# 3. Helper Functions
# ============================================================================
def get_robot_state_7d(env, robot_idx=0):
    """
    Extract the 7D pose (position + quaternion) of one robot's gripper site.

    The site is looked up by name as ``f"{env.robot_prefixes[robot_idx]}gripper_site"``.

    Args:
        env: FrankaTable4RobotsEnv instance (uses ``model``, ``data`` and
            ``robot_prefixes``)
        robot_idx: Which robot (0-3)

    Returns:
        np.ndarray of shape (7,): [x, y, z, qw, qx, qy, qz], with the quaternion
        in MuJoCo's scalar-first order. All zeros if the site is not found.
    """
    import mujoco  # local import keeps this helper self-contained

    site_name = f"{env.robot_prefixes[robot_idx]}gripper_site"
    site_id = mujoco.mj_name2id(env.model, mujoco.mjtObj.mjOBJ_SITE, site_name)
    if site_id < 0:
        # Best-effort fallback (unchanged behaviour): unknown site -> zero pose.
        return np.zeros(7)

    position = env.data.site_xpos[site_id].copy()  # [x, y, z]

    # MjData has no `site_xquat` field (reading it raised AttributeError). Site
    # orientation is stored only as a flattened 3x3 rotation matrix in `site_xmat`,
    # so it is converted to a quaternion here.
    quaternion = np.zeros(4)
    mujoco.mju_mat2Quat(quaternion, env.data.site_xmat[site_id])  # -> [w, x, y, z]

    return np.concatenate([position, quaternion])

def get_observation(env, robot_idx=0, camera_name=None):
    """
    Capture an RGB frame plus the end-effector state of one robot.

    Args:
        env: FrankaTable4RobotsEnv instance
        robot_idx: Index of the robot whose state is returned (0-3)
        camera_name: When truthy, the frame comes from
            ``env.render_camera(camera_name, width=640, height=480)``;
            otherwise from ``env.render()``.

    Returns:
        (image, state_7d):
            image: PIL Image built from the rendered array
            state_7d: output of get_robot_state_7d(env, robot_idx),
                i.e. [x, y, z, qw, qx, qy, qz]
    """
    frame = (
        env.render_camera(camera_name, width=640, height=480)
        if camera_name
        else env.render()
    )
    return Image.fromarray(frame), get_robot_state_7d(env, robot_idx)

def format_prompt(task_description):
    """Wrap a natural-language instruction in OpenVLA's "In: ... Out:" prompt template."""
    template = "In: What action should the robot take to {}?\nOut:"
    return template.format(task_description)

def vla_action_to_env_action(vla_action, robot_idx, env, delta_scale=0.1):
    """
    Convert a 7-DoF OpenVLA action into a full control vector for the environment.

    VLA action: [dx, dy, dz, droll, dpitch, dyaw, gripper]
    Env action: one value per entry of ``env.data.ctrl`` (8 per robot x 4 robots).

    This is a deliberately crude mapping. The scaled Cartesian/rotational deltas
    are added directly to the robot's first six *joint* positions. It is not
    inverse kinematics; replace it with an IK solver for meaningful motion.

    Args:
        vla_action: Array-like with at least 7 elements (e.g. from ``vla.predict_action``)
        robot_idx: Which robot this action is for (0-3)
        env: FrankaTable4RobotsEnv instance (reads ``env.data.ctrl`` and ``env.data.qpos``)
        delta_scale: Factor applied to each delta before it is added to a joint
            target (default 0.1, the previously hard-coded value)

    Returns:
        np.ndarray: a copy of ``env.data.ctrl`` in which this robot's 8 actuator
        targets are replaced; all other entries keep their current values.

    Raises:
        ValueError: If ``vla_action`` has fewer than 7 elements.
    """
    action = np.asarray(vla_action, dtype=np.float64).reshape(-1)
    if action.size < 7:
        raise ValueError(f"Expected a 7-DoF action, got {action.size} value(s)")

    # Start with current control values so the other robots hold steady.
    full_action = env.data.ctrl.copy()

    # Assumed qpos layout: 7 leading values (presumably a free-floating object's pose),
    # then 9 per robot (7 arm joints + 2 finger joints) - TODO confirm against the scene XML.
    qpos_start = 7 + robot_idx * 9
    current_joints = env.data.qpos[qpos_start:qpos_start + 7].copy()

    # Joints 0-5 receive the scaled [dx, dy, dz, droll, dpitch, dyaw]; joint 6 is untouched.
    joint_deltas = np.zeros(7)
    joint_deltas[:6] = action[:6] * delta_scale

    # Each robot owns 8 consecutive actuators: 7 arm joints followed by the gripper.
    ctrl_start = robot_idx * 8
    full_action[ctrl_start:ctrl_start + 7] = current_joints + joint_deltas

    # OpenVLA emits the gripper command in [0, 1] (1 = open, 0 = close for Bridge data),
    # not [-1, 1]. The former (g + 1) * 127.5 mapping never went below 127.5, so the
    # gripper could never close. Map [0, 1] -> [0, 255]; this assumes 255 = fully open
    # on this actuator - TODO confirm against the scene XML.
    full_action[ctrl_start + 7] = np.clip(action[6], 0.0, 1.0) * 255.0

    return full_action

In [None]:
# ============================================================================
# 4. Main Control Loop
# ============================================================================
def run_episode(task_description="move the blue cube to the green goal",
                max_steps=200,
                robot_idx=0,
                camera_name=None):
    """
    Run one episode in which OpenVLA controls a single robot.

    Relies on the notebook-level ``env``, ``processor``, ``vla`` and ``device``.

    Args:
        task_description: Task instruction for VLA
        max_steps: Maximum steps per episode (must be >= 1)
        robot_idx: Which robot to control (0-3)
        camera_name: Optional specific camera view

    Returns:
        tuple: ``(step, info)``, the index of the last executed step and the
        ``info`` dict from the final ``env.step`` call.

    Raises:
        ValueError: If ``max_steps`` < 1. Otherwise the loop would never run
            and ``step`` would be unbound at the return.
    """
    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")

    obs, info = env.reset()
    prompt = format_prompt(task_description)

    print(f"\nTask: {task_description}")
    print(f"Controlling Robot {robot_idx + 1}")
    print(f"Running for up to {max_steps} steps...\n")

    for step in range(max_steps):
        # Only the image is fed to OpenVLA; the 7D state is not a model input.
        image, _ = get_observation(env, robot_idx, camera_name)

        # Prepare inputs for VLA
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

        # Predict action (7-DoF: delta_pos[3] + delta_rot[3] + gripper[1])
        with torch.no_grad():
            raw_action = vla.predict_action(
                **inputs,
                unnorm_key="bridge_orig",
                do_sample=False
            )

        # OpenVLA's predict_action returns a NumPy array, so the former `.device`
        # and `.cpu()` calls raised AttributeError. A tensor is still accepted:
        # it is cast to float32 first because NumPy has no bfloat16 dtype.
        if isinstance(raw_action, torch.Tensor):
            raw_action = raw_action.detach().float().cpu().numpy()
        vla_action = np.asarray(raw_action)
        if vla_action.ndim > 1:
            vla_action = vla_action[0]

        if step == 0:
            print(f"✓ GPU Memory During Inference: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")

        # Convert VLA action to full environment action
        env_action = vla_action_to_env_action(vla_action, robot_idx, env)

        # Step environment
        obs, reward, terminated, truncated, info = env.step(env_action)

        # Print progress
        if step % 20 == 0:
            obj_pos = info['object_position']
            goal_dist = info['object_to_goal_distance']
            print(f"Step {step:3d}: obj=[{obj_pos[0]:.2f},{obj_pos[1]:.2f},{obj_pos[2]:.2f}], "
                  f"goal_dist={goal_dist:.3f}, reward={reward:.2f}")

        # Check success
        if info.get('success', False):
            print(f"\n✓ SUCCESS! Object reached goal at step {step}")
            break

        # Check if episode is done
        if terminated or truncated:
            print(f"\nEpisode terminated at step {step}")
            break

    return step, info

In [None]:
# ============================================================================
# 5. Multi-Robot Control (one OpenVLA query per robot per step)
# ============================================================================
def run_episode_multi_robot(task_description="move the blue cube to the green goal",
                            max_steps=200,
                            robot_indices=(0, 1, 2, 3),
                            camera_name=None):
    """
    Run one episode in which OpenVLA controls several robots simultaneously.

    Each step queries OpenVLA once per robot, merges the resulting actuator
    targets into one control vector and advances the environment once.
    OpenVLA's ``predict_action`` decodes a single sample (batch size 1), so a
    batched call cannot produce one action per robot.

    NOTE: with the current get_observation the rendered image does not depend on
    the robot index. Every robot therefore sees the same image and prompt, and with
    greedy decoding (do_sample=False) all robots receive the same action.

    Relies on the notebook-level ``env``, ``processor``, ``vla`` and ``device``.

    Args:
        task_description: Task instruction for VLA
        max_steps: Maximum steps per episode (must be >= 1)
        robot_indices: Iterable of robot indices to control (non-empty). The
            default is a tuple, to avoid a shared mutable default argument.
        camera_name: Optional specific camera view

    Returns:
        tuple: ``(step, info)``, the index of the last executed step and the
        ``info`` dict from the final ``env.step`` call.

    Raises:
        ValueError: If ``max_steps`` < 1 or ``robot_indices`` is empty.
    """
    robot_indices = list(robot_indices)
    if not robot_indices:
        raise ValueError("robot_indices must contain at least one robot index")
    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")

    obs, info = env.reset()
    prompt = format_prompt(task_description)

    print(f"\nTask: {task_description}")
    print(f"Controlling Robots: {[i+1 for i in robot_indices]}")
    print(f"Running for up to {max_steps} steps with one GPU inference per robot...\n")

    for step in range(max_steps):
        # Start with current control values so robots not listed hold steady.
        full_action = env.data.ctrl.copy()

        for robot_idx in robot_indices:
            image, _ = get_observation(env, robot_idx, camera_name)
            inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

            with torch.no_grad():
                raw_action = vla.predict_action(
                    **inputs,
                    unnorm_key="bridge_orig",
                    do_sample=False
                )

            # predict_action returns a NumPy array. A tensor is also accepted; it is
            # cast to float32 first because NumPy has no bfloat16 dtype.
            if isinstance(raw_action, torch.Tensor):
                raw_action = raw_action.detach().float().cpu().numpy()
            vla_action = np.asarray(raw_action)
            if vla_action.ndim > 1:
                vla_action = vla_action[0]

            # Reuse the single-robot mapping, then copy only this robot's 8 actuators
            # (7 arm joints + gripper) into the combined command.
            robot_ctrl = vla_action_to_env_action(vla_action, robot_idx, env)
            ctrl_start = robot_idx * 8
            full_action[ctrl_start:ctrl_start + 8] = robot_ctrl[ctrl_start:ctrl_start + 8]

        # Step environment once with all robot actions
        obs, reward, terminated, truncated, info = env.step(full_action)

        # Print progress
        if step % 20 == 0:
            obj_pos = info['object_position']
            goal_dist = info['object_to_goal_distance']
            gripper_dists = info.get('gripper_distances', [])
            # len() rather than truthiness: a multi-element NumPy array has no truth value.
            avg_gripper = np.mean(gripper_dists) if len(gripper_dists) > 0 else 0
            print(f"Step {step:3d}: obj=[{obj_pos[0]:.2f},{obj_pos[1]:.2f},{obj_pos[2]:.2f}], "
                  f"goal_dist={goal_dist:.3f}, avg_gripper_dist={avg_gripper:.3f}, reward={reward:.2f}")

        # Check success
        if info.get('success', False):
            print(f"\n✓ SUCCESS! Object reached goal at step {step}")
            break

        # Check if episode is done
        if terminated or truncated:
            print(f"\nEpisode terminated at step {step}")
            break

    return step, info