From 008be90a23a3c889f616d79d148939d4c55b883b Mon Sep 17 00:00:00 2001 From: Sai Nivedh Date: Tue, 23 Apr 2024 20:26:55 +0530 Subject: [PATCH] add flag to download exported model (#337) --- clarifai/client/model.py | 69 ++++++++++++++++++++++--------------- clarifai/constants/model.py | 1 + 2 files changed, 43 insertions(+), 27 deletions(-) diff --git a/clarifai/client/model.py b/clarifai/client/model.py index f9469229..3b510a1a 100644 --- a/clarifai/client/model.py +++ b/clarifai/client/model.py @@ -17,7 +17,8 @@ from clarifai.client.dataset import Dataset from clarifai.client.input import Inputs from clarifai.client.lister import Lister -from clarifai.constants.model import MAX_MODEL_PREDICT_INPUTS, TRAINABLE_MODEL_TYPES +from clarifai.constants.model import (MAX_MODEL_PREDICT_INPUTS, MODEL_EXPORT_TIMEOUT, + TRAINABLE_MODEL_TYPES) from clarifai.errors import UserError from clarifai.urls.helper import ClarifaiUrlHelper from clarifai.utils.logging import get_logger @@ -949,19 +950,22 @@ def export(self, export_dir: str = None) -> None: """Export the model, stores the exported model as model.tar file Args: - export_dir (str): The directory to save the exported model. + export_dir (str, optional): If provided, the exported model will be saved in the specified directory else export status will be shown. Defaults to None. Example: >>> from clarifai.client.model import Model >>> model = Model("url") + >>> model.export() + or >>> model.export('/path/to/export_model_dir') """ assert self.model_info.model_version.id, "Model version ID is missing. Please provide a `model_version` with a valid `id` as an argument or as a URL in the following format: '{user_id}/{app_id}/models/{your_model_id}/model_version_id/{your_version_model_id}' when initializing." - try: - if not os.path.exists(export_dir): - os.makedirs(export_dir) - except OSError as e: - raise Exception(f"An error occurred while creating the directory: {e}") + if export_dir: + try: + if not os.path.exists(export_dir): + os.makedirs(export_dir) + except OSError as e: + raise Exception(f"An error occurred while creating the directory: {e}") def _get_export_response(): get_export_request = service_pb2.GetModelVersionExportRequest( @@ -1010,28 +1014,39 @@ def _download_exported_model( raise Exception(response.status) self.logger.info( - f"Model ID {self.id} with version {self.model_info.model_version.id} export started, please wait..." + f"Export process has started for Model ID {self.id}, Version {self.model_info.model_version.id}" ) - time.sleep(5) - start_time = time.time() - backoff_iterator = BackoffIterator(10) - while True: - get_export_response = _get_export_response() - if get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING and \ - time.time() - start_time < 60 * 30: # 30 minutes - self.logger.info( - f"Model ID {self.id} with version {self.model_info.model_version.id} is still exporting, please wait..." - ) - time.sleep(next(backoff_iterator)) - elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED: - _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar")) - break - elif time.time() - start_time > 60 * 30: - raise Exception( - f"""Model Export took too long. Please try again or contact support@clarifai.com - Req ID: {get_export_response.status.req_id}""") + if export_dir: + start_time = time.time() + backoff_iterator = BackoffIterator(10) + while True: + get_export_response = _get_export_response() + if (get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \ + get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING) and \ + time.time() - start_time < MODEL_EXPORT_TIMEOUT: + self.logger.info( + f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..." + ) + time.sleep(next(backoff_iterator)) + elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED: + _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar")) + break + elif time.time() - start_time > MODEL_EXPORT_TIMEOUT: + raise Exception( + f"""Model Export took too long. Please try again or contact support@clarifai.com + Req ID: {get_export_response.status.req_id}""") elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTED: - _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar")) + if export_dir: + _download_exported_model(get_export_response, os.path.join(export_dir, "model.tar")) + else: + self.logger.info( + f"Model ID {self.id} with version {self.model_info.model_version.id} is already exported, you can download it from the following URL: {get_export_response.export.url}" + ) + elif get_export_response.export.status.code == status_code_pb2.MODEL_EXPORTING or \ + get_export_response.export.status.code == status_code_pb2.MODEL_EXPORT_PENDING: + self.logger.info( + f"Export process is ongoing for Model ID {self.id}, Version {self.model_info.model_version.id}. Please wait..." + ) @staticmethod def _make_pretrained_config_proto(input_field_maps: dict, diff --git a/clarifai/constants/model.py b/clarifai/constants/model.py index 8a87436f..968a1ce2 100644 --- a/clarifai/constants/model.py +++ b/clarifai/constants/model.py @@ -3,3 +3,4 @@ 'text-classifier', 'embedding-classifier', 'text-to-text' ] MAX_MODEL_PREDICT_INPUTS = 128 +MODEL_EXPORT_TIMEOUT = 1800