Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


import os
import shutil
import subprocess
from pathlib import Path

Expand All @@ -24,63 +23,65 @@

class BaseVisionAgent(BaseAgent):
WEIGHTS_URL: str = ""
DEFAULT_WEIGHTS_ROOT_PATH: Path = Path.home() / Path(".cache/rai/")
WEIGHTS_DIR_PATH_PART: Path = Path("vision/weights")
WEIGHTS_FILENAME: str = ""

def __init__(
self,
weights_path: str | Path = Path.home() / Path(".cache/rai/"),
weights_root_path: str | Path = DEFAULT_WEIGHTS_ROOT_PATH,
ros2_name: str = "",
):
if not self.WEIGHTS_FILENAME:
raise ValueError("WEIGHTS_FILENAME is not set")
super().__init__()
self._weights_path = Path(weights_path)
os.makedirs(self._weights_path, exist_ok=True)
self._init_weight_path()
self.weight_path = self._weights_path
self.weights_root_path = Path(weights_root_path)
self.weights_root_path.mkdir(parents=True, exist_ok=True)
self.weights_path = (
self.weights_root_path / self.WEIGHTS_DIR_PATH_PART / self.WEIGHTS_FILENAME
)
if not self.weights_path.exists():
self._download_weights()
self.ros2_connector = ROS2Connector(ros2_name, executor_type="single_threaded")

def _init_weight_path(self):
try:
if self.WEIGHTS_FILENAME == "":
raise ValueError("WEIGHTS_FILENAME is not set")
def _load_model_with_error_handling(self, model_class):
"""Load model with automatic error handling for corrupted weights.

install_path = (
self._weights_path / "vision" / "weights" / self.WEIGHTS_FILENAME
)
# make sure the file exists
if install_path.exists() and install_path.is_file():
self._weights_path = install_path
else:
self._remove_weights(path=install_path)
self._download_weights(install_path)
self._weights_path = install_path
Args:
model_class: A class that can be instantiated with weights_path

except Exception:
self.logger.error("Could not find package path")
raise Exception("Could not find package path")
Returns:
The loaded model instance
"""
try:
return model_class(self.weights_path)
except RuntimeError as e:
self.logger.error(f"Could not load model: {e}")
if "PytorchStreamReader" in str(e):
self.logger.error("The weights might be corrupted. Redownloading...")
self._remove_weights()
self._download_weights()
return model_class(self.weights_path)
else:
raise e

def _download_weights(self, path: Path):
def _download_weights(self):
try:
os.makedirs(path.parent, exist_ok=True)
subprocess.run(
[
"wget",
self.WEIGHTS_URL,
"-O",
path,
self.weights_path,
"--progress=dot:giga",
]
)
except Exception:
self.logger.error("Could not download weights")
raise Exception("Could not download weights")

def _remove_weights(self, path: str):
# Sometimes redownloding weights bugged and created a dir
# so check also for dir and remove it in both cases
if os.path.isdir(path):
shutil.rmtree(path)
elif os.path.isfile(path):
os.remove(path)
def _remove_weights(self):
os.remove(self.weights_path)

def stop(self):
self.ros2_connector.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,11 @@ class GroundedSamAgent(BaseVisionAgent):

def __init__(
self,
weights_path: str | Path = Path.home() / Path(".cache/rai"),
weights_root_path: str | Path = Path.home() / Path(".cache/rai"),
ros2_name: str = GSAM_NODE_NAME,
):
super().__init__(weights_path, ros2_name)
try:
self._segmenter = GDSegmenter(self._weights_path)
except Exception as e:
self.logger.error(
f"Could not load model : {e}. The weights might be corrupted. Redownloading..."
)
self._remove_weights(self.weight_path)
self._init_weight_path()
self.segmenter = GDSegmenter(self.weight_path)
super().__init__(weights_root_path, ros2_name)
self._segmenter = self._load_model_with_error_handling(GDSegmenter)
self.logger.info(f"{self.__class__.__name__} initialized")

def run(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,11 @@ class GroundingDinoAgent(BaseVisionAgent):

def __init__(
self,
weights_path: str | Path = Path.home() / Path(".cache/rai"),
weights_root_path: str | Path = Path.home() / Path(".cache/rai"),
ros2_name: str = GDINO_NODE_NAME,
):
super().__init__(weights_path, ros2_name)
try:
self._boxer = GDBoxer(self._weights_path)
except Exception as e:
self.logger.error(
f"Could not load model: {e}, The weights might be corrupted. Redownloading..."
)
self._remove_weights(self.weight_path)
self._init_weight_path()
self.segmenter = GDBoxer(self.weight_path)
super().__init__(weights_root_path, ros2_name)
self._boxer = self._load_model_with_error_handling(GDBoxer)
self.logger.info(f"{self.__class__.__name__} initialized")

def run(self):
Expand Down