Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: open-set rai refactor #420

Draft
wants to merge 6 commits into
base: refactor/openset-rai2.0
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/agriculture-demo.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,12 @@
import argparse

import rclpy
from rclpy.action import ActionClient
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from std_srvs.srv import Trigger

from rai.node import RaiStateBasedLlmNode, describe_ros_image
from rai.tools.ros.native import (
GetCameraImage,
@@ -24,12 +30,6 @@
Ros2ShowMsgInterfaceTool,
)
from rai.tools.time import WaitForSecondsTool
from rclpy.action import ActionClient
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node
from std_srvs.srv import Trigger

from rai_interfaces.action import Task


1 change: 1 addition & 0 deletions examples/manipulation-demo-streamlit.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from rai.agents.integrations.streamlit import get_streamlit_cb, streamlit_invoke
from rai.messages import HumanMultimodalMessage

47 changes: 4 additions & 43 deletions examples/manipulation-demo.launch.py
Original file line number Diff line number Diff line change
@@ -12,22 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import rclpy
from launch import LaunchContext, LaunchDescription
from launch import LaunchDescription
from launch.actions import (
DeclareLaunchArgument,
ExecuteProcess,
IncludeLaunchDescription,
OpaqueFunction,
RegisterEventHandler,
)
from launch.event_handlers import OnExecutionComplete, OnProcessStart
from launch.launch_description_sources import PythonLaunchDescriptionSource
from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare
from rclpy.qos import QoSProfile, ReliabilityPolicy
from rosgraph_msgs.msg import Clock


def generate_launch_description():
@@ -46,21 +40,6 @@ def generate_launch_description():
output="screen",
)

def wait_for_clock_message(context: LaunchContext, *args, **kwargs):
rclpy.init()
node = rclpy.create_node("wait_for_game_launcher")
node.create_subscription(
Clock,
"/clock",
lambda msg: rclpy.shutdown(),
QoSProfile(depth=1, reliability=ReliabilityPolicy.BEST_EFFORT),
)
rclpy.spin(node)
return None

# Game launcher will start publishing the clock message after loading the simulation
wait_for_game_launcher = OpaqueFunction(function=wait_for_clock_message)

launch_moveit = IncludeLaunchDescription(
PythonLaunchDescriptionSource(
[
@@ -90,28 +69,10 @@ def wait_for_clock_message(context: LaunchContext, *args, **kwargs):

return LaunchDescription(
[
# Include the game_launcher argument
game_launcher_arg,
# Launch the game launcher and wait for it to load
launch_game_launcher,
RegisterEventHandler(
event_handler=OnProcessStart(
target_action=launch_game_launcher,
on_start=[
wait_for_game_launcher,
],
)
),
# Launch the MoveIt node after loading the simulation
RegisterEventHandler(
event_handler=OnExecutionComplete(
target_action=wait_for_game_launcher,
on_completion=[
launch_openset,
launch_moveit,
launch_robotic_manipulation,
],
)
),
launch_openset,
launch_moveit,
launch_robotic_manipulation,
]
)
3 changes: 2 additions & 1 deletion examples/manipulation-demo.py
Original file line number Diff line number Diff line change
@@ -16,12 +16,13 @@
import rclpy
import rclpy.qos
from langchain_core.messages import HumanMessage
from rai_open_set_vision.tools import GetGrabbingPointTool

from rai.agents.conversational_agent import create_conversational_agent
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
from rai.utils.model_initialization import get_llm_model
from rai_open_set_vision.tools import GetGrabbingPointTool


def create_agent():
3 changes: 2 additions & 1 deletion examples/rosbot-xl-demo.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,8 @@
import rclpy
import rclpy.executors
import rclpy.logging
from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool

from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.native import (
GetMsgFromTopic,
@@ -33,7 +35,6 @@
Ros2RunActionAsync,
)
from rai.tools.time import WaitForSecondsTool
from rai_open_set_vision.tools import GetDetectionTool, GetDistanceToObjectsTool

p = argparse.ArgumentParser()
p.add_argument("--allowlist", type=Path, required=False, default=None)
4 changes: 2 additions & 2 deletions examples/taxi-demo.py
Original file line number Diff line number Diff line change
@@ -20,12 +20,12 @@
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
from std_msgs.msg import String

from rai.agents.conversational_agent import create_conversational_agent
from rai.tools.ros.cli import Ros2ServiceTool
from rai.tools.ros.native import Ros2PubMessageTool
from rai.utils.model_initialization import get_llm_model, get_tracing_callbacks
from std_msgs.msg import String

from rai_hmi.api import GenericVoiceNode, split_message

system_prompt = """
1 change: 1 addition & 0 deletions src/examples/turtlebot4/turtlebot_demo.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
import rclpy.qos
import rclpy.subscription
import rclpy.task

from rai.node import RaiStateBasedLlmNode
from rai.tools.ros.native import (
GetCameraImage,
6 changes: 3 additions & 3 deletions src/rai_core/rai/communication/ros2/connectors.py
Original file line number Diff line number Diff line change
@@ -205,11 +205,11 @@ def node(self) -> Node:
return self._node

def shutdown(self):
self._executor.shutdown()
self._thread.join()
self._node.destroy_node()
self._actions_api.shutdown()
self._topic_api.shutdown()
self._node.destroy_node()
self._executor.shutdown()
self._thread.join()


class ROS2HRIMessage(HRIMessage):
22 changes: 7 additions & 15 deletions src/rai_core/rai/tools/ros/manipulation.py
Original file line number Diff line number Diff line change
@@ -15,12 +15,6 @@
from typing import Literal, Type

import numpy as np
import rclpy
import rclpy.callback_groups
import rclpy.executors
import rclpy.qos
import rclpy.subscription
import rclpy.task
from geometry_msgs.msg import Point, Pose, PoseStamped, Quaternion
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
@@ -29,6 +23,7 @@

from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.utils import TF2TransformFetcher
from rai.utils.ros_async import get_future_result
from rai_interfaces.srv import ManipulatorMoveTo


@@ -113,17 +108,14 @@ def _run(
self.connector.node.get_logger().debug(
f"Calling ManipulatorMoveTo service with request: x={request.target_pose.pose.position.x:.2f}, y={request.target_pose.pose.position.y:.2f}, z={request.target_pose.pose.position.z:.2f}"
)
response = get_future_result(future, timeout_sec=5.0)
if response is None:
return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})."

rclpy.spin_until_future_complete(self.connector.node, future, timeout_sec=5.0)

if future.result() is not None:
response = future.result()
if response.success:
return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement."
else:
return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})."
if response.success:
return f"End effector successfully positioned at coordinates ({x:.2f}, {y:.2f}, {z:.2f}). Note: The status of object interaction (grab/drop) is not confirmed by this movement."
else:
return f"Service call failed for point ({x:.2f}, {y:.2f}, {z:.2f})."
return f"Failed to position end effector at coordinates ({x:.2f}, {y:.2f}, {z:.2f})."


class GetObjectPositionsToolInput(BaseModel):
4 changes: 3 additions & 1 deletion src/rai_core/rai/tools/ros/utils.py
Original file line number Diff line number Diff line change
@@ -151,7 +151,9 @@ def wait_for_message(
if msg_info is not None:
return True, msg_info[0]
finally:
node.destroy_subscription(sub)
# TODO(boczekbartek): uncomment when rclpy resolves: https://github.com/ros2/rclpy/issues/1142
# node.destroy_subscription(sub)
pass

return False, None

Original file line number Diff line number Diff line change
@@ -17,16 +17,16 @@
import numpy as np
import sensor_msgs.msg
from pydantic import BaseModel, Field
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
from rai.tools.ros.utils import convert_ros_img_to_ndarray
from rai.utils.ros_async import get_future_result
from rclpy.exceptions import (
ParameterNotDeclaredException,
ParameterUninitializedException,
)
from rclpy.task import Future

from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros import Ros2BaseInput, Ros2BaseTool
from rai.tools.ros.utils import convert_ros_img_to_ndarray
from rai.utils.ros_async import get_future_result
from rai_interfaces.srv import RAIGroundingDino
from rai_open_set_vision import GDINO_SERVICE_NAME

Original file line number Diff line number Diff line change
@@ -20,16 +20,16 @@
import sensor_msgs.msg
from langchain_core.tools import BaseTool
from pydantic import Field
from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros import Ros2BaseInput
from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray
from rai.utils.ros_async import get_future_result
from rclpy import Future
from rclpy.exceptions import (
ParameterNotDeclaredException,
ParameterUninitializedException,
)

from rai.communication.ros2.connectors import ROS2ARIConnector
from rai.tools.ros import Ros2BaseInput
from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray
from rai.utils.ros_async import get_future_result
from rai_interfaces.srv import RAIGroundedSam, RAIGroundingDino
from rai_open_set_vision import GDINO_SERVICE_NAME

@@ -342,21 +342,17 @@ def _run(
)
conversion_ratio = 0.001
resolved = None
while rclpy.ok():
resolved = self._get_gdino_response(future)
if resolved is not None:
break

resolved = get_future_result(future)

assert resolved is not None
future = self._call_gsam_node(camera_img_msg, resolved)

ret = []
while rclpy.ok():
resolved = self._get_gsam_response(future)
if resolved is not None:
for img_msg in resolved.masks:
ret.append(convert_ros_img_to_base64(img_msg))
break
resolved = get_future_result(future)
if resolved is not None:
for img_msg in resolved.masks:
ret.append(convert_ros_img_to_base64(img_msg))
assert resolved is not None
rets = []
for mask_msg in resolved.masks:
Original file line number Diff line number Diff line change
@@ -20,12 +20,13 @@
import numpy as np
import torch
from cv_bridge import CvBridge
from rai.tools.ros.utils import convert_ros_img_to_ndarray
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sensor_msgs.msg import Image
from vision_msgs.msg import BoundingBox2D

from rai.tools.ros.utils import convert_ros_img_to_ndarray


class GDSegmenter:
def __init__(
2 changes: 1 addition & 1 deletion src/rai_hmi/rai_hmi/agent.py
Original file line number Diff line number Diff line change
@@ -17,11 +17,11 @@
from typing import List

from langchain.tools import tool

from rai.agents.conversational_agent import create_conversational_agent
from rai.node import RaiBaseNode
from rai.tools.ros.native import GetCameraImage, Ros2GetRobotInterfaces
from rai.utils.model_initialization import get_llm_model

from rai_hmi.base import BaseHMINode
from rai_hmi.chat_msgs import MissionMessage
from rai_hmi.task import Task, TaskInput
4 changes: 2 additions & 2 deletions src/rai_hmi/rai_hmi/base.py
Original file line number Diff line number Diff line change
@@ -22,13 +22,13 @@
from langchain_core.documents import Document
from langchain_core.tools import BaseTool
from pydantic import UUID4
from rai.node import append_whoami_info_to_prompt
from rai.utils.model_initialization import get_embeddings_model
from rclpy.action import ActionClient
from rclpy.node import Node
from std_msgs.msg import String
from std_srvs.srv import Trigger

from rai.node import append_whoami_info_to_prompt
from rai.utils.model_initialization import get_embeddings_model
from rai_hmi.chat_msgs import (
MissionAcceptanceMessage,
MissionDoneMessage,
2 changes: 1 addition & 1 deletion src/rai_hmi/rai_hmi/ros.py
Original file line number Diff line number Diff line change
@@ -19,9 +19,9 @@
from typing import Optional, Tuple

import rclpy
from rai.node import RaiBaseNode
from rclpy.executors import MultiThreadedExecutor

from rai.node import RaiBaseNode
from rai_hmi.base import BaseHMINode


6 changes: 3 additions & 3 deletions src/rai_hmi/rai_hmi/text_hmi.py
Original file line number Diff line number Diff line change
@@ -33,12 +33,12 @@
)
from PIL import Image
from pydantic import BaseModel
from rai.messages import HumanMultimodalMessage
from rai.node import RaiBaseNode
from rai.utils.artifacts import get_stored_artifacts
from rclpy.node import Node
from streamlit.delta_generator import DeltaGenerator

from rai.messages import HumanMultimodalMessage
from rai.node import RaiBaseNode
from rai.utils.artifacts import get_stored_artifacts
from rai_hmi.agent import initialize_agent
from rai_hmi.base import BaseHMINode
from rai_hmi.chat_msgs import EMOJIS, MissionMessage
2 changes: 1 addition & 1 deletion src/rai_hmi/rai_hmi/voice_hmi.py
Original file line number Diff line number Diff line change
@@ -22,12 +22,12 @@

import rclpy
from langchain_core.messages import HumanMessage
from rai.node import RaiBaseNode
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
from std_msgs.msg import String

from rai.node import RaiBaseNode
from rai_hmi.agent import initialize_agent
from rai_hmi.base import BaseHMINode
from rai_hmi.text_hmi_utils import Memory
Loading
Oops, something went wrong.