From 4666d78cf6382b4c0ce3eb90364e8778dc9067d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kotowski?= Date: Fri, 12 Sep 2025 18:32:35 +0200 Subject: [PATCH 1/3] fix: redownload weights only on PytorchStreamReader error --- .../agents/base_vision_agent.py | 67 ++++++++++--------- .../agents/grounded_sam.py | 10 +-- .../agents/grounding_dino.py | 14 +--- 3 files changed, 38 insertions(+), 53 deletions(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py index c143b4d13..08087b018 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py @@ -14,7 +14,6 @@ import os -import shutil import subprocess from pathlib import Path @@ -24,49 +23,56 @@ class BaseVisionAgent(BaseAgent): WEIGHTS_URL: str = "" + DEFAULT_WEIGHTS_ROOT_PATH: Path = Path.home() / Path(".cache/rai/") + WEIGHTS_DIR_PATH_PART: Path = Path("vision/weights") WEIGHTS_FILENAME: str = "" def __init__( self, - weights_path: str | Path = Path.home() / Path(".cache/rai/"), + weights_root_path: str | Path = DEFAULT_WEIGHTS_ROOT_PATH, ros2_name: str = "", ): + if not self.WEIGHTS_FILENAME: + raise ValueError("WEIGHTS_FILENAME is not set") super().__init__() - self._weights_path = Path(weights_path) - os.makedirs(self._weights_path, exist_ok=True) - self._init_weight_path() - self.weight_path = self._weights_path + self.weights_root_path = Path(weights_root_path) + self.weights_root_path.mkdir(parents=True, exist_ok=True) + self.weights_path = ( + self.weights_root_path / self.WEIGHTS_DIR_PATH_PART / self.WEIGHTS_FILENAME + ) + if not self.weights_path.exists(): + self.download_weights() self.ros2_connector = ROS2Connector(ros2_name, executor_type="single_threaded") - def _init_weight_path(self): - try: - if self.WEIGHTS_FILENAME == "": - raise ValueError("WEIGHTS_FILENAME is not set") + def _load_model_with_error_handling(self, model_class): + """Load model with automatic error handling for corrupted weights. - install_path = ( - self._weights_path / "vision" / "weights" / self.WEIGHTS_FILENAME - ) - # make sure the file exists - if install_path.exists() and install_path.is_file(): - self._weights_path = install_path - else: - self._remove_weights(path=install_path) - self._download_weights(install_path) - self._weights_path = install_path + Args: + model_class: A class that can be instantiated with weights_path - except Exception: - self.logger.error("Could not find package path") - raise Exception("Could not find package path") + Returns: + The loaded model instance + """ + try: + return model_class(self.weights_path) + except RuntimeError as e: + self.logger.error(f"Could not load model: {e}") + if "PytorchStreamReader" in str(e): + self.logger.error("The weights might be corrupted. Redownloading...") + self.remove_weights() + self.download_weights() + return model_class(self.weights_path) + else: + raise e - def _download_weights(self, path: Path): + def download_weights(self): try: - os.makedirs(path.parent, exist_ok=True) subprocess.run( [ "wget", self.WEIGHTS_URL, "-O", - path, + self.weights_path, "--progress=dot:giga", ] ) @@ -74,13 +80,8 @@ def _download_weights(self, path: Path): self.logger.error("Could not download weights") raise Exception("Could not download weights") - def _remove_weights(self, path: str): - # Sometimes redownloding weights bugged and created a dir - # so check also for dir and remove it in both cases - if os.path.isdir(path): - shutil.rmtree(path) - elif os.path.isfile(path): - os.remove(path) + def remove_weights(self): + os.remove(self.weights_path) def stop(self): self.ros2_connector.shutdown() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py index be9ded1d0..5f7e57f3d 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py @@ -38,15 +38,7 @@ def __init__( ros2_name: str = GSAM_NODE_NAME, ): super().__init__(weights_path, ros2_name) - try: - self._segmenter = GDSegmenter(self._weights_path) - except Exception as e: - self.logger.error( - f"Could not load model : {e}. The weights might be corrupted. Redownloading..." - ) - self._remove_weights(self.weight_path) - self._init_weight_path() - self.segmenter = GDSegmenter(self.weight_path) + self._segmenter = self._load_model_with_error_handling(GDSegmenter) self.logger.info(f"{self.__class__.__name__} initialized") def run(self): diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounding_dino.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounding_dino.py index b4f138958..bb01ee4a9 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounding_dino.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounding_dino.py @@ -29,19 +29,11 @@ class GroundingDinoAgent(BaseVisionAgent): def __init__( self, - weights_path: str | Path = Path.home() / Path(".cache/rai"), + weights_root_path: str | Path = Path.home() / Path(".cache/rai"), ros2_name: str = GDINO_NODE_NAME, ): - super().__init__(weights_path, ros2_name) - try: - self._boxer = GDBoxer(self._weights_path) - except Exception as e: - self.logger.error( - f"Could not load model: {e}, The weights might be corrupted. Redownloading..." - ) - self._remove_weights(self.weight_path) - self._init_weight_path() - self.segmenter = GDBoxer(self.weight_path) + super().__init__(weights_root_path, ros2_name) + self._boxer = self._load_model_with_error_handling(GDBoxer) self.logger.info(f"{self.__class__.__name__} initialized") def run(self): From 7dc6ba1525a7f9a26a066f87d4793698eca4ab7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Kotowski?= Date: Wed, 17 Sep 2025 16:51:08 +0200 Subject: [PATCH 2/3] rename arg --- .../rai_open_set_vision/agents/grounded_sam.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py index 5f7e57f3d..62bc96deb 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py @@ -34,10 +34,10 @@ class GroundedSamAgent(BaseVisionAgent): def __init__( self, - weights_path: str | Path = Path.home() / Path(".cache/rai"), + weights_root_path: str | Path = Path.home() / Path(".cache/rai"), ros2_name: str = GSAM_NODE_NAME, ): - super().__init__(weights_path, ros2_name) + super().__init__(weights_root_path, ros2_name) self._segmenter = self._load_model_with_error_handling(GDSegmenter) self.logger.info(f"{self.__class__.__name__} initialized") From fa8a1538d5c1caa31f7236b2de3074ceccd18c73 Mon Sep 17 00:00:00 2001 From: maciejmajek Date: Wed, 1 Oct 2025 21:27:15 +0200 Subject: [PATCH 3/3] style: make remove and download weights private --- .../rai_open_set_vision/agents/base_vision_agent.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py index 08087b018..851e74b24 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py @@ -41,7 +41,7 @@ def __init__( self.weights_root_path / self.WEIGHTS_DIR_PATH_PART / self.WEIGHTS_FILENAME ) if not self.weights_path.exists(): - self.download_weights() + self._download_weights() self.ros2_connector = ROS2Connector(ros2_name, executor_type="single_threaded") def _load_model_with_error_handling(self, model_class): @@ -59,13 +59,13 @@ def _load_model_with_error_handling(self, model_class): self.logger.error(f"Could not load model: {e}") if "PytorchStreamReader" in str(e): self.logger.error("The weights might be corrupted. Redownloading...") - self.remove_weights() - self.download_weights() + self._remove_weights() + self._download_weights() return model_class(self.weights_path) else: raise e - def download_weights(self): + def _download_weights(self): try: subprocess.run( [ @@ -80,7 +80,7 @@ def download_weights(self): self.logger.error("Could not download weights") raise Exception("Could not download weights") - def remove_weights(self): + def _remove_weights(self): os.remove(self.weights_path) def stop(self):