In [5]:
!pip install accelerate

Collecting accelerate
  Obtaining dependency information for accelerate from https://files.pythonhosted.org/packages/9f/d2/c581486aa6c4fbd7394c23c47b83fa1a919d34194e16944241daf9e762dd/accelerate-1.12.0-py3-none-any.whl.metadata
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Downloading accelerate-1.12.0-py3-none-any.whl (380 kB)
   ---------------------------------------- 0.0/380.9 kB ? eta -:--:--
   - -------------------------------------- 10.2/380.9 kB ? eta -:--:--
   ------ --------------------------------- 61.4/380.9 kB 1.1 MB/s eta 0:00:01
   -------------------------- ------------- 256.0/380.9 kB 2.6 MB/s eta 0:00:01
   ---------------------------------------- 380.9/380.9 kB 3.4 MB/s eta 0:00:00
Installing collected packages: accelerate
Successfully installed accelerate-1.12.0



[notice] A new release of pip is available: 23.2.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
import gymnasium as gym
import numpy as np
from PIL import Image
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
import mujoco


import sys
from pathlib import Path

# Add franka_table parent directory to path
sys.path.insert(0, str(Path().resolve().parent))

from environments.franka_4robots_env import FrankaTable4RobotsEnv

In [2]:
torch.cuda.empty_cache()  # Clear any cached memory
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Current device: {torch.cuda.current_device()}")
print(f"Device count: {torch.cuda.device_count()}")

CUDA available: False


AssertionError: Torch not compiled with CUDA enabled

In [7]:
# ============================================================================
# 1. Setup Environment
# ============================================================================
# Using a robot manipulation environment (e.g., Franka Panda)
# Replace with your specific environment
# env = gym.make(
#     "scenes/scene_4robots.xml",  # custom MuJoCo env
#     render_mode="rgb_array"
# )

env_scene = "../scenes/scene_4robots.xml"
render_mode = "rgb_array"
env = FrankaTable4RobotsEnv(mjcf_path=env_scene, render_mode=render_mode)

Franka Table 4-Robots Environment initialized:
  - Number of robots: 4
  - Number of actuators: 32 (8 per robot)
  - Number of bodies: 47
  - Number of joints (nq): 43
  - Number of velocities (nv): 42
  - Observation dimension: 80
  - Action dimension: 32


In [None]:
sys.path.insert(0, str(Path().resolve().parent.parent.parent))

from lerobot.src.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")

ModuleNotFoundError: No module named 'lerobot.common'

In [None]:
# ============================================================================
# 2. Load OpenVLA Model (GPU Optimized)
# ============================================================================
print("Loading OpenVLA model with GPU optimizations...")

# Set device
device = "cuda" if torch.cuda.is_available() else ValueError("CUDA device not found. A GPU is required to run this model.")
print(f"Using device: {device}")

if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", trust_remote_code=True)

vla = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # Flash Attention 2 for speed (requires flash-attn package)
)

print(f"Model loaded successfully on {device}!")
print(f"Model device: {vla.device}")
print(f"Model dtype: {vla.dtype}")

Loading OpenVLA model with GPU optimizations...
Using device: CUDA device not found. A GPU is required to run this model.


In [None]:
# ============================================================================
# 3. Helper Functions
# ============================================================================

def get_robot_state_7d(env, robot_idx=0):
    """
    Extract 7D state (position + quaternion) for a specific robot's end-effector.
    
    Args:
        env: FrankaTable4RobotsEnv instance
        robot_idx: Which robot (0-3)
    
    Returns:
        np.ndarray: [x, y, z, qw, qx, qy, qz]
    """
    prefix = env.robot_prefixes[robot_idx]
    gripper_site_name = f"{prefix}gripper_site"
    
    # Get site ID
    import mujoco
    gripper_site_id = mujoco.mj_name2id(
        env.model, 
        mujoco.mjtObj.mjOBJ_SITE, 
        gripper_site_name
    )
    
    if gripper_site_id >= 0:
        # Get position
        position = env.data.site_xpos[gripper_site_id].copy()  # [x, y, z]
        
        # Get orientation (quaternion in MuJoCo format: [w, x, y, z])
        quaternion = env.data.site_xquat[gripper_site_id].copy()  # [w, x, y, z]
        
        # Combine into 7D state
        state_7d = np.concatenate([position, quaternion])
        return state_7d
    else:
        # Fallback if site not found
        return np.zeros(7)

def get_observation(env, robot_idx=0, camera_name=None):
    """
    Get RGB image and robot state from environment.
    
    Args:
        env: FrankaTable4RobotsEnv instance
        robot_idx: Which robot to get state for (0-3)
        camera_name: Optional camera name for specific view
    
    Returns:
        image: PIL Image
        state_7d: np.ndarray of shape (7,) - [x, y, z, qw, qx, qy, qz]
    """
    # Get RGB image
    if camera_name:
        rgb_array = env.render_camera(camera_name, width=640, height=480)
    else:
        rgb_array = env.render()
    
    image = Image.fromarray(rgb_array)
    
    # Get 7D state for the specified robot
    state_7d = get_robot_state_7d(env, robot_idx)
    
    return image, state_7d

def format_prompt(task_description):
    """Format the prompt for OpenVLA"""
    return f"In: What action should the robot take to {task_description}?\nOut:"

def vla_action_to_env_action(vla_action, robot_idx, env):
    """
    Convert VLA 7-DoF action to environment's 32-actuator action.
    
    VLA action: [dx, dy, dz, droll, dpitch, dyaw, gripper]
    Env action: 32 values (8 per robot × 4 robots)
    
    This is a simplified mapping - you may need more sophisticated IK.
    
    Args:
        vla_action: np.ndarray of shape (7,) from VLA
        robot_idx: Which robot this action is for (0-3)
        env: FrankaTable4RobotsEnv instance
    
    Returns:
        np.ndarray of shape (32,) for environment
    """
    # Start with current control values (hold other robots steady)
    full_action = env.data.ctrl.copy()
    
    # Extract components from VLA action
    delta_pos = vla_action[:3]  # [dx, dy, dz]
    delta_rot = vla_action[3:6]  # [droll, dpitch, dyaw]
    gripper = vla_action[6]  # gripper command
    
    # Get current robot joint positions
    qpos_start = 7 + robot_idx * 9
    current_joints = env.data.qpos[qpos_start:qpos_start+7].copy()
    
    # Simple approach: small joint space movements
    # For a real implementation, you'd use inverse kinematics here
    # This just applies small deltas to current joint positions
    joint_deltas = np.zeros(7)
    joint_deltas[0] = delta_pos[0] * 0.1  # Scale down for stability
    joint_deltas[1] = delta_pos[1] * 0.1
    joint_deltas[2] = delta_pos[2] * 0.1
    joint_deltas[3] = delta_rot[0] * 0.1
    joint_deltas[4] = delta_rot[1] * 0.1
    joint_deltas[5] = delta_rot[2] * 0.1
    
    # Apply to the specific robot's actuators
    ctrl_start = robot_idx * 8
    full_action[ctrl_start:ctrl_start+7] = current_joints + joint_deltas
    
    # Set gripper (maps -1 to 1 → 0 to 255)
    gripper_value = (gripper + 1) * 127.5  # Map [-1, 1] to [0, 255]
    full_action[ctrl_start+7] = np.clip(gripper_value, 0, 255)
    
    return full_action

In [None]:
# ============================================================================
# 4. Main Control Loop
# ============================================================================
def run_episode(task_description="pick up the object", max_steps=100):
    """Run one episode with VLA control"""
    
    obs, info = env.reset()
    prompt = format_prompt(task_description)
    
    print(f"\nTask: {task_description}")
    print(f"Running for {max_steps} steps...\n")
    
    for step in range(max_steps):
        # Get current observation
        image, state_7d = get_observation(env)
        
        # Prepare inputs for VLA
        inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
        
        # Predict action (7-DoF: delta_pos[3] + delta_rot[3] + gripper[1])
        with torch.no_grad():
            action = vla.predict_action(
                **inputs,
                unnorm_key="bridge_orig",
                do_sample=False
            )
        
        # Convert to numpy and ensure correct shape
        action = action.cpu().numpy()
        if action.ndim > 1:
            action = action[0]
        
        # Step environment
        obs, reward, terminated, truncated, info = env.step(action)
        
        # Print progress
        if step % 10 == 0:
            print(f"Step {step}: action = {action}, reward = {reward:.3f}")
        
        # Check if episode is done
        if terminated or truncated:
            print(f"\nEpisode finished at step {step}")
            break
    
    return step