diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index faf380eb..9f1ac34a 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -9,29 +9,79 @@ on: workflow_dispatch: jobs: - setup-and-test: + test: runs-on: ubuntu-latest + timeout-minutes: 45 # Global timeout fallback strategy: fail-fast: false matrix: test-suite: [ - 'tests/unit', - 'tests/functional/file_asset', - 'tests/functional/data_asset', - 'tests/functional/benchmark', - 'tests/functional/model', - 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory', - 'tests/functional/pipelines/designer_test.py', - 'tests/functional/pipelines/create_test.py', - 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory', - 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory', - 'tests/functional/general_assets', - 'tests/functional/apikey', - 'tests/functional/agent tests/functional/team_agent', + 'unit', + 'file_asset', + 'data_asset', + 'benchmark', + 'model', + 'pipeline_2.0_v1', + 'pipeline_2.0_v2', + 'pipeline_3.0_v1', + 'pipeline_3.0_v2', + 'pipeline_designer', + 'pipeline_create', + 'finetune_v1', + 'finetune_v2', + 'general_assets', + 'apikey', + 'agent_and_team_agent', ] + include: + - test-suite: 'unit' + path: 'tests/unit' + timeout: 45 # Tweaked timeout for each unit tests + - test-suite: 'file_asset' + path: 'tests/functional/file_asset' + timeout: 45 + - test-suite: 'data_asset' + path: 'tests/functional/data_asset' + timeout: 45 + - test-suite: 'benchmark' + path: 'tests/functional/benchmark' + timeout: 45 + - test-suite: 'model' + path: 'tests/functional/model' + timeout: 45 + - test-suite: 'pipeline_2.0_v1' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v1 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_2.0_v2' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 2.0 --sdk_version v2 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_3.0_v1' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v1 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_3.0_v2' + path: 'tests/functional/pipelines/run_test.py --pipeline_version 3.0 --sdk_version v2 --sdk_version_param PipelineFactory' + timeout: 45 + - test-suite: 'pipeline_designer' + path: 'tests/functional/pipelines/designer_test.py' + timeout: 45 + - test-suite: 'pipeline_create' + path: 'tests/functional/pipelines/create_test.py' + timeout: 45 + - test-suite: 'finetune_v1' + path: 'tests/functional/finetune --sdk_version v1 --sdk_version_param FinetuneFactory' + timeout: 45 + - test-suite: 'finetune_v2' + path: 'tests/functional/finetune --sdk_version v2 --sdk_version_param FinetuneFactory' + timeout: 45 + - test-suite: 'general_assets' + path: 'tests/functional/general_assets' + timeout: 45 + - test-suite: 'apikey' + path: 'tests/functional/apikey' + timeout: 45 + - test-suite: 'agent_and_team_agent' + path: 'tests/functional/agent tests/functional/team_agent' + timeout: 45 steps: - name: Checkout repository uses: actions/checkout@v4 @@ -62,4 +112,5 @@ jobs: fi - name: Run Tests - run: python -m pytest ${{ matrix.test-suite}} \ No newline at end of file + timeout-minutes: ${{ matrix.timeout }} + run: python -m pytest ${{ matrix.path }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 456aba3b..2500c2c8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,4 +22,4 @@ repos: hooks: - id: flake8 args: # arguments to configure flake8 - - --ignore=E402,E501,E203 \ No newline at end of file + - --ignore=E402,E501,E203,W503 \ No newline at end of file diff --git a/aixplain/enums/asset_status.py b/aixplain/enums/asset_status.py index 994212fb..3d1e4323 100644 --- a/aixplain/enums/asset_status.py +++ b/aixplain/enums/asset_status.py @@ -43,3 +43,4 @@ class AssetStatus(Text, Enum): COMPLETED = "completed" CANCELING = "canceling" CANCELED = "canceled" + DEPRECATED_DRAFT = "deprecated_draft" diff --git a/aixplain/enums/embedding_model.py b/aixplain/enums/embedding_model.py index c52387b2..31618580 100644 --- a/aixplain/enums/embedding_model.py +++ b/aixplain/enums/embedding_model.py @@ -27,9 +27,7 @@ class EmbeddingModel(str, Enum): JINA_CLIP_V2_MULTIMODAL = "67c5f705d8f6a65d6f74d732" MULTILINGUAL_E5_LARGE = "67efd0772a0a850afa045af3" BGE_M3 = "67f401032a0a850afa045b19" - - - + AIXPLAIN_LEGAL_EMBEDDINGS = "681254b668e47e7844c1f15a" def __str__(self): return self._value_ diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index a860a539..566be092 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -25,8 +25,8 @@ from enum import Enum from urllib.parse import urljoin from aixplain.utils import config -from aixplain.utils.request_utils import _request_with_retry from aixplain.utils.cache_utils import save_to_cache, load_from_cache, CACHE_FOLDER +from aixplain.utils.request_utils import _request_with_retry CACHE_FILE = f"{CACHE_FOLDER}/licenses.json" diff --git a/aixplain/enums/supplier.py b/aixplain/enums/supplier.py index 26058bf5..18a3e81d 100644 --- a/aixplain/enums/supplier.py +++ b/aixplain/enums/supplier.py @@ -24,7 +24,7 @@ import logging from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from enum import Enum from urllib.parse import urljoin import re diff --git a/aixplain/exceptions/__init__.py b/aixplain/exceptions/__init__.py new file mode 100644 index 00000000..5d645c62 --- /dev/null +++ b/aixplain/exceptions/__init__.py @@ -0,0 +1,116 @@ +""" +Error message registry for aiXplain SDK. + +This module maintains a centralized registry of error messages used throughout the aiXplain ecosystem. +It allows developers to look up existing error messages and reuse them instead of creating new ones. +""" + +from aixplain.exceptions.types import ( + AixplainBaseException, + AuthenticationError, + ValidationError, + ResourceError, + BillingError, + SupplierError, + NetworkError, + ServiceError, + InternalError, +) + + +def get_error_from_status_code(status_code: int, error_details: str = None) -> AixplainBaseException: + """ + Map HTTP status codes to appropriate exception types. + + Args: + status_code (int): The HTTP status code to map. + default_message (str, optional): The default message to use if no specific message is available. + + Returns: + AixplainBaseException: An exception of the appropriate type. + """ + try: + if isinstance(status_code, str): + status_code = int(status_code) + except Exception as e: + raise InternalError(f"Failed to get status code from {status_code}: {e}") from e + + error_details = f"Details: {error_details}" if error_details else "" + if status_code == 400: + return ValidationError( + message=f"Bad request: Please verify the request payload and ensure it is correct. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 401: + return AuthenticationError( + message=f"Unauthorized API key: Please verify the spelling of the API key and its current validity. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 402: + return BillingError( + message=f"Payment required: Please ensure you have enough credits to run this asset. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 403: + # 403 could be auth or resource, using ResourceError as a general 'forbidden' + return ResourceError( + message=f"Forbidden access: Please verify the API key and its current validity. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 404: + # Added 404 mapping + return ResourceError( + message=f"Resource not found: Please verify the spelling of the resource and its current availability. {error_details}".strip(), + status_code=status_code, + ) + elif status_code == 429: + # Using SupplierError for rate limiting as per your original function + return SupplierError( + message=f"Rate limit exceeded: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 500: + return InternalError( + message=f"Internal server error: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 503: + return ServiceError( + message=f"Service unavailable: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif status_code == 504: + return NetworkError( + message=f"Gateway timeout: Please try again later. {error_details}".strip(), + status_code=status_code, + retry_recommended=True, + ) + elif 460 <= status_code < 470: + return ResourceError( + message=f"Subscription-related error: Please ensure that your subscription is active and has not expired. {error_details}".strip(), + status_code=status_code, + ) + elif 470 <= status_code < 480: + return BillingError( + message=f"Billing-related error: Please ensure you have enough credits to run this asset. {error_details}".strip(), + status_code=status_code, + ) + elif 480 <= status_code < 490: + return SupplierError( + message=f"Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. {error_details}".strip(), + status_code=status_code, + ) + elif 490 <= status_code < 500: + return ValidationError( + message=f"Validation-related error: Please verify the request payload and ensure it is correct. {error_details}".strip(), + status_code=status_code, + ) + else: + # Catch-all for other client/server errors + category = "Client" if 400 <= status_code < 500 else "Server" + return InternalError( + message=f"Unspecified {category} Error (Status {status_code}) {error_details}".strip(), status_code=status_code + ) diff --git a/aixplain/exceptions/types.py b/aixplain/exceptions/types.py new file mode 100644 index 00000000..e41c7fd2 --- /dev/null +++ b/aixplain/exceptions/types.py @@ -0,0 +1,217 @@ +from enum import Enum +from typing import Optional, Dict, Any + + +class ErrorSeverity(str, Enum): + """Severity levels for errors.""" + + INFO = "info" # Informational, not an error + WARNING = "warning" # Warning, operation can continue + ERROR = "error" # Error, operation cannot continue + CRITICAL = "critical" # System stability might be compromised + + +class ErrorCategory(Enum): + """Categorizes errors by their domain.""" + + AUTHENTICATION = "authentication" # API keys, permissions + VALIDATION = "validation" # Input validation + RESOURCE = "resource" # Resource availability + BILLING = "billing" # Credits, payment + SUPPLIER = "supplier" # External supplier issues + NETWORK = "network" # Network connectivity + SERVICE = "service" # Service availability + INTERNAL = "internal" # Internal system errors + AGENT = "agent" # Agent-specific errors + UNKNOWN = "unknown" # Uncategorized errors + + +class ErrorCode(str, Enum): + """Standard error codes for aiXplain exceptions. + + The format is AX--, where is a short identifier + derived from the ErrorCategory (e.g., AUTH, VAL, RES) and is a + unique sequential number within that category, starting from 1000. + + How to Add a New Error Code: + 1. Identify the appropriate `ErrorCategory` for the new error. + 2. Determine the next available sequential ID within that category. + For example, if `AX-AUTH-1000` exists, the next authentication-specific + error could be `AX-AUTH-1001`. + 3. Define the new enum member using the format `AX--`. + Use a concise abbreviation for the category (e.g., AUTH, VAL, RES, BIL, + SUP, NET, SVC, INT). + 4. Assign the string value (e.g., `"AX-AUTH-1001"`). + 5. Add a clear docstring explaining the specific condition that triggers + this error code. + 6. (Optional but recommended) Consider creating a more specific exception + class inheriting from the corresponding category exception (e.g., + `class InvalidApiKeyError(AuthenticationError): ...`) and assign the + new error code to it. + """ + + AX_AUTH_ERROR = "AX-AUTH-1000" # General authentication error. Use for issues like invalid API keys, insufficient permissions, or failed login attempts. + AX_VAL_ERROR = "AX-VAL-1000" # General validation error. Use when user-provided input fails validation checks (e.g., incorrect data type, missing required fields, invalid format. + AX_RES_ERROR = "AX-RES-1000" # General resource error. Use for issues related to accessing or managing resources, such as a requested model being unavailable or quota limits exceeded. + AX_BIL_ERROR = "AX-BIL-1000" # General billing error. Use for problems related to billing, payments, or credits (e.g., insufficient funds, expired subscription. + AX_SUP_ERROR = "AX-SUP-1000" # General supplier error. Use when an error originates from an external supplier or third-party service integrated with aiXplain. + AX_NET_ERROR = "AX-NET-1000" # General network error. Use for issues related to network connectivity, such as timeouts, DNS resolution failures, or unreachable services. + AX_SVC_ERROR = "AX-SVC-1000" # General service error. Use when a specific aiXplain service or endpoint is unavailable or malfunctioning (e.g., service downtime, internal component failure. + AX_INT_ERROR = "AX-INT-1000" # General internal error. Use for unexpected server-side errors that are not covered by other categories. This often indicates a bug or an issue within the aiXplain platform itself. + + +class AixplainBaseException(Exception): + """Base exception class for all aiXplain exceptions.""" + + def __init__( + self, + message: str, + category: ErrorCategory = ErrorCategory.UNKNOWN, + severity: ErrorSeverity = ErrorSeverity.ERROR, + status_code: Optional[int] = None, + details: Optional[Dict[str, Any]] = None, + retry_recommended: bool = False, + error_code: Optional[ErrorCode] = None, + ): + self.message = message + self.category = category + self.severity = severity + self.status_code = status_code + self.details = details or {} + self.retry_recommended = retry_recommended + self.error_code = error_code + super().__init__(self.message) + + def __str__(self): + error_code_str = f" [{self.error_code}]" if self.error_code else "" + return f"{self.__class__.__name__}{error_code_str}: {self.message}" + + def to_dict(self) -> Dict[str, Any]: + """Convert exception to dictionary for serialization.""" + return { + "message": self.message, + "category": self.category.value, + "severity": self.severity.value, + "status_code": self.status_code, + "details": self.details, + "retry_recommended": self.retry_recommended, + "error_code": self.error_code.value if self.error_code else None, + } + + +class AuthenticationError(AixplainBaseException): + """Raised when authentication fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_AUTH_ERROR, + **kwargs, + ) + + +class ValidationError(AixplainBaseException): + """Raised when input validation fails.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.VALIDATION, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_VAL_ERROR, + **kwargs, + ) + + +class ResourceError(AixplainBaseException): + """Raised when a resource is unavailable.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.RESOURCE, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_RES_ERROR, + **kwargs, + ) + + +class BillingError(AixplainBaseException): + """Raised when there are billing issues.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.BILLING, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", False), + error_code=ErrorCode.AX_BIL_ERROR, + **kwargs, + ) + + +class SupplierError(AixplainBaseException): + """Raised when there are issues with external suppliers.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.SUPPLIER, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_SUP_ERROR, + **kwargs, + ) + + +class NetworkError(AixplainBaseException): + """Raised when there are network connectivity issues.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.NETWORK, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_NET_ERROR, + **kwargs, + ) + + +class ServiceError(AixplainBaseException): + """Raised when a service is unavailable.""" + + def __init__(self, message: str, **kwargs): + super().__init__( + message=message, + category=ErrorCategory.SERVICE, + severity=ErrorSeverity.ERROR, + retry_recommended=kwargs.pop("retry_recommended", True), + error_code=ErrorCode.AX_SVC_ERROR, + **kwargs, + ) + + +class InternalError(AixplainBaseException): + """Raised when there is an internal system error.""" + + def __init__(self, message: str, **kwargs): + # Server errors (5xx) should generally be retryable + status_code = kwargs.get("status_code") + retry_recommended = kwargs.pop("retry_recommended", False) + if status_code and status_code in [500, 502, 503, 504]: + retry_recommended = True + + super().__init__( + message=message, + category=ErrorCategory.INTERNAL, + severity=ErrorSeverity.ERROR, + retry_recommended=retry_recommended, + error_code=ErrorCode.AX_INT_ERROR, + **kwargs, + ) diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index 040bcd71..9440ff81 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -41,7 +41,7 @@ from aixplain.utils import config from typing import Callable, Dict, List, Optional, Text, Union -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from aixplain.enums import DatabaseSourceType diff --git a/aixplain/factories/api_key_factory.py b/aixplain/factories/api_key_factory.py index c719c26b..3d081e27 100644 --- a/aixplain/factories/api_key_factory.py +++ b/aixplain/factories/api_key_factory.py @@ -3,7 +3,7 @@ import aixplain.utils.config as config from datetime import datetime from typing import Text, List, Optional, Dict, Union -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.api_key import APIKey, APIKeyLimits, APIKeyUsageLimit diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index 3f643a3d..c37f17a8 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -32,7 +32,7 @@ from aixplain.factories.dataset_factory import DatasetFactory from aixplain.factories.model_factory import ModelFactory from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index db7aa44e..9563ad14 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -38,7 +38,7 @@ from aixplain.enums.language import Language from aixplain.enums.license import License from aixplain.enums.privacy import Privacy -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.utils import config from pathlib import Path from tqdm import tqdm diff --git a/aixplain/factories/data_factory.py b/aixplain/factories/data_factory.py index 1879b321..65aa8a87 100644 --- a/aixplain/factories/data_factory.py +++ b/aixplain/factories/data_factory.py @@ -28,15 +28,11 @@ from aixplain.modules.data import Data from aixplain.enums.data_subtype import DataSubtype from aixplain.enums.data_type import DataType -from aixplain.enums.function import Function from aixplain.enums.language import Language -from aixplain.enums.license import License from aixplain.enums.privacy import Privacy -from aixplain.utils.file_utils import _request_with_retry -from aixplain.utils import config -from typing import Any, Dict, List, Text +from aixplain.utils.request_utils import _request_with_retry +from typing import Dict, Text from urllib.parse import urljoin -from uuid import uuid4 class DataFactory(AssetFactory): diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index ca9d993e..3b86b45e 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -41,7 +41,8 @@ from aixplain.enums.privacy import Privacy from aixplain.utils import config from aixplain.utils.convert_datatype_utils import dict_to_metadata -from aixplain.utils.file_utils import _request_with_retry, s3_to_csv +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import s3_to_csv from aixplain.utils.validation_utils import dataset_onboarding_validation from pathlib import Path from tqdm import tqdm diff --git a/aixplain/factories/finetune_factory/__init__.py b/aixplain/factories/finetune_factory/__init__.py index 238d0d0c..b6006f0b 100644 --- a/aixplain/factories/finetune_factory/__init__.py +++ b/aixplain/factories/finetune_factory/__init__.py @@ -33,7 +33,7 @@ from aixplain.modules.dataset import Dataset from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin @@ -98,7 +98,7 @@ def create( if prompt_template is not None: prompt_template = validate_prompt(prompt_template, dataset_list) try: - url = urljoin(cls.backend_url, f"sdk/finetune/cost-estimation") + url = urljoin(cls.backend_url, "sdk/finetune/cost-estimation") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} payload = { "datasets": [ diff --git a/aixplain/factories/metric_factory.py b/aixplain/factories/metric_factory.py index 9f42fb3e..6279ffc1 100644 --- a/aixplain/factories/metric_factory.py +++ b/aixplain/factories/metric_factory.py @@ -22,14 +22,12 @@ """ import logging -import os from typing import List, Optional from aixplain.modules import Metric from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Text from urllib.parse import urljoin -from warnings import warn class MetricFactory: @@ -113,7 +111,7 @@ def list( List[Metric]: List of supported metrics """ try: - url = urljoin(cls.backend_url, f"sdk/metrics") + url = urljoin(cls.backend_url, "sdk/metrics") filter_params = {} if model_id is not None: filter_params["modelId"] = model_id diff --git a/aixplain/factories/model_factory/__init__.py b/aixplain/factories/model_factory/__init__.py index b39cc668..85c1ac4f 100644 --- a/aixplain/factories/model_factory/__init__.py +++ b/aixplain/factories/model_factory/__init__.py @@ -27,7 +27,7 @@ from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 1be2186c..589bdf68 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -6,10 +6,11 @@ from aixplain.modules.model.utility_model import UtilityModel, UtilityModelInput from aixplain.enums import DataType, Function, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from datetime import datetime from typing import Dict, Union, List, Optional, Tuple from urllib.parse import urljoin +from aixplain.enums import AssetStatus import requests @@ -61,7 +62,6 @@ def create_model_from_response(response: Dict) -> Model: for param in response["params"] ] input_params = model_params - if not code: if "version" in response and response["version"]: version_link = response["version"]["id"] @@ -74,6 +74,8 @@ def create_model_from_response(response: Dict) -> Model: else: raise Exception("Utility Model Error: Code not found") + status = AssetStatus(response.get("status", AssetStatus.DRAFT.value)) + created_at = None if "createdAt" in response and response["createdAt"]: created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00")) @@ -96,7 +98,7 @@ def create_model_from_response(response: Dict) -> Model: version=response["version"]["id"], inputs=inputs, temperature=temperature, - status=response.get("status", AssetStatus.DRAFT), + status=status, ) diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index cfbfce54..ba164199 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.model import Model from aixplain.modules.pipeline import Pipeline from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from warnings import warn diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 892d7ded..6a1db846 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.team_agent import TeamAgent, InspectorTarget from aixplain.utils import config from aixplain.factories.team_agent_factory.utils import build_team_agent -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class TeamAgentFactory: diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index 1591dc2e..2de28ec4 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -1,6 +1,6 @@ import aixplain.utils.config as config from aixplain.modules.wallet import Wallet -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry import logging from typing import Text diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 227c0040..8f19d3ec 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -27,25 +27,22 @@ import traceback from aixplain.utils.file_utils import _request_with_retry -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier -from aixplain.enums.asset_status import AssetStatus -from aixplain.enums.storage_type import StorageType +from aixplain.enums import Function, Supplier, AssetStatus, StorageType, ResponseStatus from aixplain.modules.model import Model from aixplain.modules.agent.agent_task import AgentTask from aixplain.modules.agent.output_format import OutputFormat from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.agent_response import AgentResponse from aixplain.modules.agent.agent_response_data import AgentResponseData -from aixplain.enums import ResponseStatus from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin from aixplain.utils import config +from aixplain.modules.mixins import DeployableMixin -class Agent(Model): +class Agent(Model, DeployableMixin[Tool]): """Advanced AI system capable of performing tasks by leveraging specialized software tools and resources from aiXplain marketplace. Attributes: @@ -212,6 +209,8 @@ def run( poll_url = response["url"] end = time.time() result = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + # if result.status == ResponseStatus.FAILED: + # raise Exception("Model failed to run with error: " + result.error_message) result_data = result.get("data") or {} return AgentResponse( status=ResponseStatus.SUCCESS, @@ -419,11 +418,5 @@ def save(self) -> None: """Save the Agent.""" self.update() - def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Agent must be in draft status to be deployed." - assert self.status != AssetStatus.ONBOARDED, "Agent is already deployed." - self.status = AssetStatus.ONBOARDED - self.update() - def __repr__(self): return f"Agent(id={self.id}, name={self.name}, function={self.function})" diff --git a/aixplain/modules/agent/tool/__init__.py b/aixplain/modules/agent/tool/__init__.py index aefa093a..93dc269d 100644 --- a/aixplain/modules/agent/tool/__init__.py +++ b/aixplain/modules/agent/tool/__init__.py @@ -23,6 +23,7 @@ from abc import ABC from typing import Optional, Text from aixplain.utils import config +from aixplain.enums import AssetStatus class Tool(ABC): @@ -40,6 +41,7 @@ def __init__( description: Text, version: Optional[Text] = None, api_key: Optional[Text] = config.TEAM_API_KEY, + status: Optional[AssetStatus] = AssetStatus.DRAFT, **additional_info, ) -> None: """Specialized software or resource designed to assist the AI in executing specific tasks or functions based on user commands. @@ -55,6 +57,7 @@ def __init__( self.version = version self.api_key = api_key self.additional_info = additional_info + self.status = status def to_dict(self): """Converts the tool to a dictionary.""" diff --git a/aixplain/modules/agent/tool/custom_python_code_tool.py b/aixplain/modules/agent/tool/custom_python_code_tool.py index d544d2b1..8ec3ab9e 100644 --- a/aixplain/modules/agent/tool/custom_python_code_tool.py +++ b/aixplain/modules/agent/tool/custom_python_code_tool.py @@ -24,6 +24,7 @@ from typing import Text, Union, Callable, Optional from aixplain.modules.agent.tool import Tool import logging +from aixplain.enums import AssetStatus class CustomPythonCodeTool(Tool): @@ -35,6 +36,8 @@ def __init__( """Custom Python Code Tool""" super().__init__(name=name or "", description=description, **additional_info) self.code = code + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool + self.validate() def to_dict(self): @@ -67,6 +70,12 @@ def validate(self): ), "Custom Python Code Tool Error: Tool description is required" assert self.code and self.code.strip() != "", "Custom Python Code Tool Error: Code is required" assert self.name and self.name.strip() != "", "Custom Python Code Tool Error: Name is required" + assert self.status in [ + AssetStatus.DRAFT, + AssetStatus.ONBOARDED, + ], "Custom Python Code Tool Error: Status must be DRAFT or ONBOARDED" + + def __repr__(self) -> Text: return f"CustomPythonCodeTool(name={self.name})" diff --git a/aixplain/modules/agent/tool/model_tool.py b/aixplain/modules/agent/tool/model_tool.py index 6a945a15..9b073a84 100644 --- a/aixplain/modules/agent/tool/model_tool.py +++ b/aixplain/modules/agent/tool/model_tool.py @@ -22,8 +22,7 @@ """ from typing import Optional, Union, Text, Dict, List -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier +from aixplain.enums import AssetStatus, Function, Supplier from aixplain.modules.agent.tool import Tool from aixplain.modules.model import Model @@ -83,8 +82,26 @@ def __init__( """ name = name or "" super().__init__(name=name, description=description, **additional_info) + status = AssetStatus.ONBOARDED if model is None else AssetStatus.DRAFT + model_id = model # if None, Set id to None as default + self.model_object = None # Store the actual model object for parameter access + + if isinstance(model, Model): + model_id = model.id + status = model.status + self.model_object = model # Store the Model object + elif isinstance(model, Text): + # get model from id + try: + self.model_object = self._get_model(model) # Store the Model object + model_id = self.model_object.id + status = self.model_object.status + except Exception: + raise Exception(f"Model Tool Unavailable. Make sure Model '{model}' exists or you have access to it.") + self.supplier = supplier - self.model = model + self.model = model_id + self.status = status self.function = function self.parameters = parameters self.validate() @@ -109,6 +126,7 @@ def to_dict(self) -> Dict: "version": self.version if self.version else None, "assetId": self.model.id if self.model is not None and isinstance(self.model, Model) else self.model, "parameters": self.parameters, + "status": self.status, } def validate(self) -> None: @@ -123,7 +141,6 @@ def validate(self) -> None: - If the description is empty, it sets the description to the function description or the model description. """ from aixplain.enums import FunctionInputOutput - from aixplain.factories.model_factory import ModelFactory assert ( self.function is not None or self.model is not None @@ -145,7 +162,7 @@ def validate(self) -> None: if self.model is not None: if isinstance(self.model, Text) is True: try: - self.model = ModelFactory.get(self.model, api_key=self.api_key) + self.model = self._get_model() except Exception: raise Exception(f"Model Tool Unavailable. Make sure Model '{self.model}' exists or you have access to it.") self.function = self.model.function @@ -166,8 +183,22 @@ def validate(self) -> None: self.name = self.name if self.name else set_tool_name(self.function, self.supplier, self.model) def get_parameters(self) -> Dict: + # If parameters were not explicitly provided, get them from the model + if ( + self.parameters is None + and self.model_object is not None # noqa: W503 + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + return self.model_object.model_params.to_list() return self.parameters + def _get_model(self, model_id: Text = None): + from aixplain.factories.model_factory import ModelFactory + + model_id = model_id or self.model + return ModelFactory.get(model_id, api_key=self.api_key) + def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) -> Optional[List[Dict]]: """Validates and formats the parameters for the tool. @@ -182,8 +213,12 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) """ if received_parameters is None: # Get default parameters if none provided - if self.model is not None and self.model.model_params is not None: - return self.model.model_params.to_list() + if ( + self.model_object is not None + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + return self.model_object.model_params.to_list() elif self.function is not None: function_params = self.function.get_parameters() if function_params is not None: @@ -192,8 +227,12 @@ def validate_parameters(self, received_parameters: Optional[List[Dict]] = None) # Get expected parameters expected_params = None - if self.model is not None and self.model.model_params is not None: - expected_params = self.model.model_params + if ( + self.model_object is not None + and hasattr(self.model_object, "model_params") # noqa: W503 + and self.model_object.model_params is not None # noqa: W503 + ): + expected_params = self.model_object.model_params elif self.function is not None: expected_params = self.function.get_parameters() diff --git a/aixplain/modules/agent/tool/pipeline_tool.py b/aixplain/modules/agent/tool/pipeline_tool.py index 0de83916..728256d2 100644 --- a/aixplain/modules/agent/tool/pipeline_tool.py +++ b/aixplain/modules/agent/tool/pipeline_tool.py @@ -24,6 +24,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.pipeline import Pipeline +from aixplain.enums import AssetStatus class PipelineTool(Tool): @@ -50,6 +51,8 @@ def __init__( name = name or "" super().__init__(name=name, description=description, **additional_info) + self.status = AssetStatus.DRAFT + self.pipeline = pipeline self.validate() @@ -59,8 +62,12 @@ def to_dict(self): "name": self.name, "description": self.description, "type": "pipeline", + "status": self.status, } + def __repr__(self) -> Text: + return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" + def validate(self): from aixplain.factories.pipeline_factory import PipelineFactory @@ -76,6 +83,5 @@ def validate(self): if self.name.strip() == "": self.name = pipeline_obj.name + self.status = pipeline_obj.status - def __repr__(self) -> Text: - return f"PipelineTool(name={self.name}, pipeline={self.pipeline})" diff --git a/aixplain/modules/agent/tool/python_interpreter_tool.py b/aixplain/modules/agent/tool/python_interpreter_tool.py index 42621c45..5f6f93f6 100644 --- a/aixplain/modules/agent/tool/python_interpreter_tool.py +++ b/aixplain/modules/agent/tool/python_interpreter_tool.py @@ -22,6 +22,8 @@ """ from aixplain.modules.agent.tool import Tool +from aixplain.enums import AssetStatus + from typing import Text @@ -32,6 +34,7 @@ def __init__(self, **additional_info) -> None: """Python Interpreter Tool""" description = "A Python shell. Use this to execute python commands. Input should be a valid python command." super().__init__(name="Python Interpreter", description=description, **additional_info) + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool def to_dict(self): return { diff --git a/aixplain/modules/agent/tool/sql_tool.py b/aixplain/modules/agent/tool/sql_tool.py index 56cf116c..f2262ee9 100644 --- a/aixplain/modules/agent/tool/sql_tool.py +++ b/aixplain/modules/agent/tool/sql_tool.py @@ -27,7 +27,7 @@ import numpy as np from typing import Text, Optional, Dict, List, Union import sqlite3 - +from aixplain.enums import AssetStatus from aixplain.modules.agent.tool import Tool @@ -284,6 +284,7 @@ def __init__( self.schema = schema self.tables = tables if isinstance(tables, list) else [tables] if tables else None self.enable_commit = enable_commit + self.status = AssetStatus.ONBOARDED # TODO: change to DRAFT when we have a way to onboard the tool def to_dict(self) -> Dict[str, Text]: return { diff --git a/aixplain/modules/api_key.py b/aixplain/modules/api_key.py index ae774c23..d606e106 100644 --- a/aixplain/modules/api_key.py +++ b/aixplain/modules/api_key.py @@ -1,6 +1,6 @@ import logging from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules import Model from datetime import datetime from typing import Dict, List, Optional, Text, Union diff --git a/aixplain/modules/benchmark.py b/aixplain/modules/benchmark.py index d76b2e62..6878becf 100644 --- a/aixplain/modules/benchmark.py +++ b/aixplain/modules/benchmark.py @@ -26,7 +26,7 @@ from aixplain.modules import Asset, Dataset, Metric, Model from aixplain.modules.benchmark_job import BenchmarkJob from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class Benchmark(Asset): diff --git a/aixplain/modules/benchmark_job.py b/aixplain/modules/benchmark_job.py index 29a33aa7..cd17c0e1 100644 --- a/aixplain/modules/benchmark_job.py +++ b/aixplain/modules/benchmark_job.py @@ -4,7 +4,8 @@ from urllib.parse import urljoin import pandas as pd from pathlib import Path -from aixplain.utils.file_utils import _request_with_retry, save_file +from aixplain.utils.request_utils import _request_with_retry +from aixplain.utils.file_utils import save_file class BenchmarkJob: @@ -72,7 +73,7 @@ def download_results_as_csv(self, save_path: Optional[Text] = None, return_dataf logging.info(f"Downloading Benchmark Results: Status of downloading results for {self.id}: {resp}") if "reportUrl" not in resp or resp["reportUrl"] == "": logging.error( - f"Downloading Benchmark Results: Can't get download results as they aren't generated yet. Please wait for a while." + "Downloading Benchmark Results: Can't get download results as they aren't generated yet. Please wait for a while." ) return None csv_url = resp["reportUrl"] diff --git a/aixplain/modules/corpus.py b/aixplain/modules/corpus.py index b65664b6..10101292 100644 --- a/aixplain/modules/corpus.py +++ b/aixplain/modules/corpus.py @@ -29,8 +29,8 @@ from aixplain.modules.asset import Asset from aixplain.modules.data import Data from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry -from typing import Any, List, Optional, Text +from aixplain.utils.request_utils import _request_with_retry +from typing import List, Optional, Text from urllib.parse import urljoin diff --git a/aixplain/modules/dataset.py b/aixplain/modules/dataset.py index fd79e9f3..85264013 100644 --- a/aixplain/modules/dataset.py +++ b/aixplain/modules/dataset.py @@ -30,7 +30,7 @@ from aixplain.modules.asset import Asset from aixplain.modules.data import Data from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin from typing import Any, Dict, List, Optional, Text diff --git a/aixplain/modules/finetune/__init__.py b/aixplain/modules/finetune/__init__.py index fe2cb15c..15cc37a7 100644 --- a/aixplain/modules/finetune/__init__.py +++ b/aixplain/modules/finetune/__init__.py @@ -31,7 +31,7 @@ from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry class Finetune(Asset): diff --git a/aixplain/modules/mixins.py b/aixplain/modules/mixins.py new file mode 100644 index 00000000..0b402cf2 --- /dev/null +++ b/aixplain/modules/mixins.py @@ -0,0 +1,74 @@ +""" +Copyright 2024 The aiXplain SDK authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 +Description: + Mixins for common functionality across different asset types +""" +from abc import ABC +from typing import TypeVar, Generic +from aixplain.enums import AssetStatus + +T = TypeVar("T") + + +class DeployableMixin(ABC, Generic[T]): + """A mixin that provides common deployment-related functionality for assets. + + This mixin provides methods for: + 1. Filtering items that are not onboarded + 2. Validating if an asset is ready to be deployed + 3. Deploying an asset + + Classes that inherit from this mixin should: + 1. Implement _validate_deployment_readiness to call the parent implementation with their specific asset type + 2. Optionally override deploy() if they need special deployment handling + """ + + def _validate_deployment_readiness(self) -> None: + """Validate if the asset is ready to be deployed. + + Args: + asset_type (str): Type of asset being validated (e.g. "Agent", "Team Agent", "Pipeline") + items (Optional[List[T]], optional): List of items to validate (e.g. tools for Agent, agents for TeamAgent) + + Raises: + ValueError: If the asset is not ready to be deployed + """ + asset_type = self.__class__.__name__ + if self.status == AssetStatus.ONBOARDED: + raise ValueError(f"{asset_type} is already deployed.") + + if self.status != AssetStatus.DRAFT: + raise ValueError(f"{asset_type} must be in DRAFT status to be deployed.") + + def deploy(self) -> None: + """Deploy the asset. + + This method validates that the asset is ready to be deployed and updates its status to ONBOARDED. + Classes that need special deployment handling should override this method. + + Raises: + ValueError: If the asset is not ready to be deployed + """ + self._validate_deployment_readiness() + previous_status = self.status + try: + self.status = AssetStatus.ONBOARDED + self.update() + except Exception as e: + self.status = previous_status + raise Exception(f"Error deploying because of backend error: {e}") from e diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index adedfcfb..49cf4591 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -28,12 +28,13 @@ from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Union, Optional, Text, Dict from datetime import datetime from aixplain.modules.model.response import ModelResponse from aixplain.enums.response_status import ResponseStatus from aixplain.modules.model.model_parameters import ModelParameters +from aixplain.enums import AssetStatus class Model(Asset): @@ -72,6 +73,7 @@ def __init__( input_params: Optional[Dict] = None, output_params: Optional[Dict] = None, model_params: Optional[Dict] = None, + status: Optional[AssetStatus] = AssetStatus.ONBOARDED, # default status for models is ONBOARDED **additional_info, ) -> None: """Model Init @@ -89,6 +91,7 @@ def __init__( input_params (Dict, optional): input parameters for the function. output_params (Dict, optional): output parameters for the function. model_params (Dict, optional): parameters for the function. + status (AssetStatus, optional): status of the model. Defaults to None. **additional_info: Any additional Model info to be saved """ super().__init__(id, name, description, supplier, version, cost=cost) @@ -102,6 +105,12 @@ def __init__( self.input_params = input_params self.output_params = output_params self.model_params = ModelParameters(model_params) if model_params else None + if isinstance(status, str): + try: + status = AssetStatus(status) + except Exception: + status = AssetStatus.ONBOARDED + self.status = status def to_dict(self) -> Dict: """Get the model info as a Dictionary @@ -119,6 +128,7 @@ def to_dict(self) -> Dict: "input_params": self.input_params, "output_params": self.output_params, "model_params": self.model_params.to_dict(), + "status": self.status, } def get_parameters(self) -> ModelParameters: @@ -199,6 +209,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: status = ResponseStatus.FAILED else: status = ResponseStatus.IN_PROGRESS + logging.debug(f"Single Poll for Model: Status of polling for {name}: {resp}") return ModelResponse( status=resp.pop("status", status), @@ -209,6 +220,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse: used_credits=resp.pop("usedCredits", 0), run_time=resp.pop("runTime", 0), usage=resp.pop("usage", None), + error_code=resp.get("error_code", None), **resp, ) except Exception as e: @@ -264,6 +276,7 @@ def run( used_credits=response.pop("usedCredits", 0), run_time=response.pop("runTime", 0), usage=response.pop("usage", None), + error_code=response.get("error_code", None), **response, ) @@ -306,7 +319,7 @@ def check_finetune_status(self, after_epoch: Optional[int] = None): Returns: FinetuneStatus: The status of the FineTune model. """ - from aixplain.enums.asset_status import AssetStatus + from aixplain.enums import AssetStatus from aixplain.modules.finetune.status import FinetuneStatus headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index cf60d0a2..0b8e6cf8 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -166,6 +166,7 @@ def run( used_credits=response.pop("usedCredits", 0), run_time=response.pop("runTime", 0), usage=response.pop("usage", None), + error_code=response.get("error_code", None), **response, ) diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index ac9f8184..a0cf08f8 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -1,5 +1,6 @@ from typing import Text, Any, Optional, Dict, List, Union from aixplain.enums import ResponseStatus +from aixplain.exceptions.types import ErrorCode class ModelResponse: @@ -16,6 +17,7 @@ def __init__( run_time: float = 0.0, usage: Optional[Dict] = None, url: Optional[Text] = None, + error_code: Optional[ErrorCode] = None, **kwargs, ): self.status = status @@ -31,6 +33,7 @@ def __init__( self.run_time = run_time self.usage = usage self.url = url + self.error_code = error_code self.additional_fields = kwargs def __getitem__(self, key: Text) -> Any: @@ -82,6 +85,8 @@ def __repr__(self) -> str: fields.append(f"usage={self.usage}") if self.url: fields.append(f"url='{self.url}'") + if self.error_code: + fields.append(f"error_code='{self.error_code}'") if self.additional_fields: fields.extend([f"{k}={repr(v)}" for k, v in self.additional_fields.items()]) return f"ModelResponse({', '.join(fields)})" @@ -104,6 +109,7 @@ def to_dict(self) -> Dict[Text, Any]: "run_time": self.run_time, "usage": self.usage, "url": self.url, + "error_code": self.error_code, } if self.additional_fields: base_dict.update(self.additional_fields) diff --git a/aixplain/modules/model/utility_model.py b/aixplain/modules/model/utility_model.py index 96454181..1cbbf3da 100644 --- a/aixplain/modules/model/utility_model.py +++ b/aixplain/modules/model/utility_model.py @@ -21,14 +21,15 @@ import logging import warnings from aixplain.enums import Function, Supplier, DataType -from aixplain.enums.asset_status import AssetStatus +from aixplain.enums import AssetStatus from aixplain.modules.model import Model from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from aixplain.modules.model.utils import parse_code_decorated from dataclasses import dataclass from typing import Callable, Union, Optional, List, Text, Dict from urllib.parse import urljoin +from aixplain.modules.mixins import DeployableMixin @dataclass @@ -88,7 +89,7 @@ def decorator(func): return decorator -class UtilityModel(Model): +class UtilityModel(Model, DeployableMixin): """Ready-to-use Utility Model. Note: Non-deployed utility models (status=DRAFT) will expire after 24 hours after creation. @@ -107,6 +108,7 @@ class UtilityModel(Model): function (Function, optional): model AI function. Defaults to None. is_subscribed (bool, optional): Is the user subscribed. Defaults to False. cost (Dict, optional): model price. Defaults to None. + status (AssetStatus, optional): status of the model. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Model info to be saved """ @@ -155,6 +157,7 @@ def __init__( function=function, is_subscribed=is_subscribed, api_key=api_key, + status=status, **additional_info, ) self.url = config.MODELS_RUN_URL @@ -274,9 +277,3 @@ def delete(self): message = f"Utility Model Deletion Error: {response}" logging.error(message) raise Exception(f"{message}") - - def deploy(self) -> None: - assert self.status == AssetStatus.DRAFT, "Utility Model must be in draft status to be deployed." - assert self.status != AssetStatus.ONBOARDED, "Utility Model is already deployed." - self.status = AssetStatus.ONBOARDED - self.update() diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index 7ba42f2d..6f3f9319 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -4,6 +4,7 @@ import logging from aixplain.utils.file_utils import _request_with_retry from typing import Callable, Dict, List, Text, Tuple, Union, Optional +from aixplain.exceptions import get_error_from_status_code def build_payload(data: Union[Text, Dict], parameters: Optional[Dict] = None): @@ -61,22 +62,12 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: else: response = resp else: - resp = resp["error"] if isinstance(resp, dict) and "error" in resp else resp - if r.status_code == 401: - error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}" - elif 460 <= r.status_code < 470: - error = f"Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {resp}" - elif 470 <= r.status_code < 480: - error = f"Billing-related error: Please ensure you have enough credits to run this model. Details: {resp}" - elif 480 <= r.status_code < 490: - error = f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {resp}" - elif 490 <= r.status_code < 500: - error = f"{resp}" - else: - status_code = str(r.status_code) - error = f"Status {status_code} - Unspecified error: {resp}" - response = {"status": "FAILED", "error_message": error, "completed": True} + error_details = resp["error"] if isinstance(resp, dict) and "error" in resp else resp + status_code = r.status_code + error = get_error_from_status_code(status_code, error_details) + logging.error(f"Error in request: {r.status_code}: {error}") + response = {"status": "FAILED", "error_message": error.message, "completed": True} return response diff --git a/aixplain/modules/pipeline/asset.py b/aixplain/modules/pipeline/asset.py index 0642e6ee..cc234337 100644 --- a/aixplain/modules/pipeline/asset.py +++ b/aixplain/modules/pipeline/asset.py @@ -1,7 +1,7 @@ __author__ = "aiXplain" """ -Copyright 2022 The aiXplain SDK authors +Copyright 2024 The aiXplain SDK authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,27 +15,28 @@ See the License for the specific language governing permissions and limitations under the License. -Author: Duraikrishna Selvaraju, Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli -Date: September 1st 2022 +Author: Thiago Castro Ferreira, Shreyas Sharma and Lucas Pavanelli +Date: November 25th 2024 Description: - Pipeline Class + Pipeline Asset Class """ import time import json import os import logging -from aixplain.enums.asset_status import AssetStatus -from aixplain.enums.response_status import ResponseStatus -from aixplain.modules.asset import Asset +from aixplain.enums import AssetStatus, ResponseStatus +from aixplain.modules import Asset from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from typing import Dict, Optional, Text, Union from urllib.parse import urljoin from aixplain.modules.pipeline.response import PipelineResponse +from aixplain.modules.mixins import DeployableMixin +from aixplain.exceptions import get_error_from_status_code -class Pipeline(Asset): +class Pipeline(Asset, DeployableMixin): """Representing a custom pipeline that was created on the aiXplain Platform Attributes: @@ -45,6 +46,7 @@ class Pipeline(Asset): url (Text, optional): running URL of platform. Defaults to config.BACKEND_URL. supplier (Text, optional): Pipeline supplier. Defaults to "aiXplain". version (Text, optional): version of the pipeline. Defaults to "1.0". + status (AssetStatus, optional): Pipeline status. Defaults to AssetStatus.DRAFT. **additional_info: Any additional Pipeline info to be saved """ @@ -410,33 +412,20 @@ def run_async( return res else: - if r.status_code == 401: - error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." - elif 460 <= r.status_code < 470: - error = "Subscription-related error: Please ensure that your subscription is active and has not expired." - elif 470 <= r.status_code < 480: - error = "Billing-related error: Please ensure you have enough credits to run this pipeline. " - elif 480 <= r.status_code < 490: - error = "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access." - elif 490 <= r.status_code < 500: - error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." - else: - status_code = str(r.status_code) - error = ( - f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - ) + status_code = r.status_code + error = get_error_from_status_code(status_code) logging.error(f"Error in request for {name} (Pipeline ID '{self.id}') - {r.status_code}: {error}") if response_version == "v1": return { "status": "failed", - "error": error, + "error": error.message, "elapsed_time": None, **kwargs, } return PipelineResponse( status=ResponseStatus.FAILED, - error={"error": error, "status": "ERROR"}, + error={"error": error.message, "status": "ERROR"}, elapsed_time=None, **kwargs, ) @@ -593,13 +582,23 @@ def save( raise Exception(e) def deploy(self, api_key: Optional[Text] = None) -> None: - """Deploy the Pipeline.""" - assert self.status == "draft", "Pipeline Deployment Error: Pipeline must be in draft status." - assert self.status != "onboarded", "Pipeline Deployment Error: Pipeline must be onboarded." + """Deploy the Pipeline. + This method overrides the deploy method in DeployableMixin to handle + Pipeline-specific deployment functionality. + + Args: + api_key (Optional[Text], optional): Team API Key to deploy the Pipeline. Defaults to None. + """ + self._validate_deployment_readiness() pipeline = self.to_dict() - self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) - self.status = AssetStatus.ONBOARDED + previous_status = self.status + try: + self.status = AssetStatus.ONBOARDED + self.update(pipeline=pipeline, save_as_asset=True, api_key=api_key, name=self.name) + except Exception as e: + self.status = previous_status + raise Exception(f"Error deploying because of backend error: {e}") from e def __repr__(self): return f"Pipeline(id={self.id}, name={self.name})" diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 3014adb1..449e9549 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -42,6 +42,7 @@ from aixplain.modules.agent.utils import process_variables from aixplain.utils import config from aixplain.utils.file_utils import _request_with_retry +from aixplain.modules.mixins import DeployableMixin class InspectorTarget(str, Enum): @@ -53,7 +54,7 @@ def __str__(self): return self._value_ -class TeamAgent(Model): +class TeamAgent(Model, DeployableMixin[Agent]): """Advanced AI system capable of using multiple agents to perform a variety of tasks. Attributes: @@ -416,14 +417,3 @@ def update(self) -> None: else: error_msg = f"Team Agent Update Error (HTTP {r.status_code}): {resp}" raise Exception(error_msg) - - def save(self) -> None: - """Save the Team Agent.""" - self.update() - - def deploy(self) -> None: - """Deploy the Team Agent.""" - assert self.status == AssetStatus.DRAFT, "Team Agent Deployment Error: Team Agent must be in draft status." - assert self.status != AssetStatus.ONBOARDED, "Team Agent Deployment Error: Team Agent must be onboarded." - self.status = AssetStatus.ONBOARDED - self.update() diff --git a/aixplain/processes/data_onboarding/onboard_functions.py b/aixplain/processes/data_onboarding/onboard_functions.py index 01a3fe9b..09d1b153 100644 --- a/aixplain/processes/data_onboarding/onboard_functions.py +++ b/aixplain/processes/data_onboarding/onboard_functions.py @@ -18,7 +18,7 @@ from aixplain.modules.dataset import Dataset from aixplain.modules.file import File from aixplain.modules.metadata import MetaData -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Text, Union from urllib.parse import urljoin @@ -203,11 +203,10 @@ def build_payload_dataset( "description": dataset.description, "function": dataset.function.value, "onboardingErrorsPolicy": error_handler.value, - "tags": dataset.tags, + "tags": dataset.tags or tags, "privacy": dataset.privacy.value, "license": {"typeId": dataset.license.value}, "refData": ref_data, - "tags": tags, "data": [], "input": [], "hypotheses": [], diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index 58781dfb..f2bb55bc 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -20,10 +20,9 @@ import aixplain.utils.config as config from aixplain.enums.license import License - +from aixplain.utils.request_utils import _request_with_retry from collections import defaultdict from pathlib import Path -from requests.adapters import HTTPAdapter, Retry from typing import Any, Optional, Text, Union, Dict, List from uuid import uuid4 from urllib.parse import urljoin, urlparse @@ -52,24 +51,6 @@ def save_file(download_url: Text, download_file_path: Optional[Any] = None) -> A return download_file_path -def _request_with_retry(method: Text, url: Text, **params) -> requests.Response: - """Wrapper around requests with Session to retry in case it fails - - Args: - method (Text): HTTP method, such as 'GET' or 'HEAD'. - url (Text): The URL of the resource to fetch. - **params: Params to pass to request function - - Returns: - requests.Response: Response object of the request - """ - session = requests.Session() - retries = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]) - session.mount("https://", HTTPAdapter(max_retries=retries)) - response = session.request(method=method.upper(), url=url, **params) - return response - - def download_data(url_link, local_filename=None): if local_filename is None: local_filename = url_link.split("/")[-1] diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 0a19a8d3..73bab1fe 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -369,7 +369,6 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): ) assert tool is not None assert tool.description == "Execute an SQL query and return the result" - agent = AgentFactory.create( name="Teste", @@ -397,7 +396,6 @@ def test_sql_tool(delete_agents_and_team_agents, AgentFactory): os.remove("ftest.db") agent.delete() - @pytest.mark.parametrize("AgentFactory", [AgentFactory, v2.Agent]) def test_sql_tool_with_csv(delete_agents_and_team_agents, AgentFactory): assert delete_agents_and_team_agents @@ -597,4 +595,4 @@ def test_agent_with_pipeline_tool(delete_agents_and_team_agents, AgentFactory): assert "hello" in answer["data"]["output"].lower() assert "hello pipeline" in answer["data"]["intermediate_steps"][0]["tool_steps"][0]["tool"].lower() - + diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 916d6077..dd8cb04e 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -99,6 +99,8 @@ def run_index_model(index_model): pytest.param(EmbeddingModel.SNOWFLAKE_ARCTIC_EMBED_L_V2_0, AirParams, id="AIR - Snowflake Arctic Embed L v2.0"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="AIR - Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="AIR - BGE M3"), + pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="AIR - aiXplain Legal Embeddings"), + ], ) def test_index_model(embedding_model, supplier_params): @@ -123,6 +125,8 @@ def test_index_model(embedding_model, supplier_params): pytest.param(EmbeddingModel.JINA_CLIP_V2_MULTIMODAL, AirParams, id="Jina Clip v2 Multimodal"), pytest.param(EmbeddingModel.MULTILINGUAL_E5_LARGE, AirParams, id="Multilingual E5 Large"), pytest.param(EmbeddingModel.BGE_M3, AirParams, id="BGE M3"), + pytest.param(EmbeddingModel.AIXPLAIN_LEGAL_EMBEDDINGS, AirParams, id="aiXplain Legal Embeddings"), + ], ) def test_index_model_with_filter(embedding_model, supplier_params): @@ -235,3 +239,8 @@ def test_index_model_air_with_image(): assert "hurricane" in second_record.lower() index_model.delete() + + import os + + if os.path.exists("hurricane.jpeg"): + os.remove("hurricane.jpeg") diff --git a/tests/test_utils.py b/tests/test_utils.py index 264538d5..3d93fe25 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from aixplain.utils.file_utils import _request_with_retry +from aixplain.utils.request_utils import _request_with_retry from urllib.parse import urljoin import logging from aixplain.utils import config @@ -12,7 +12,7 @@ def delete_asset(model_id, api_key): def delete_service_account(api_key): - delete_url = urljoin(config.BACKEND_URL, f"sdk/ecr/logout") + delete_url = urljoin(config.BACKEND_URL, "sdk/ecr/logout") logging.debug(f"URL: {delete_url}") headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} _ = _request_with_retry("post", delete_url, headers=headers) diff --git a/tests/unit/agent/model_tool_test.py b/tests/unit/agent/model_tool_test.py index 869979e1..5dfc736e 100644 --- a/tests/unit/agent/model_tool_test.py +++ b/tests/unit/agent/model_tool_test.py @@ -4,11 +4,10 @@ from unittest.mock import MagicMock from aixplain.modules.agent.tool.model_tool import ModelTool -from aixplain.enums.function import Function -from aixplain.enums.supplier import Supplier from aixplain.modules.model import Model from aixplain.modules.model.model_parameters import ModelParameters from aixplain.base.parameters import Parameter +from aixplain.enums import AssetStatus, Function, Supplier @pytest.fixture @@ -19,6 +18,7 @@ def mock_model(): model.supplier = Supplier.AIXPLAIN model.name = "Test Model" model.description = "Test Model Description" + model.status = AssetStatus.ONBOARDED model.model_params = ModelParameters( { "sourcelanguage": {"name": "sourcelanguage", "required": True}, @@ -107,6 +107,7 @@ def test_to_dict(mocker, mock_model, mock_model_factory): "version": None, "assetId": "test_model_id", "parameters": [{"name": "sourcelanguage", "value": "en"}, {"name": "targetlanguage", "value": "es"}], + "status": mock_model.status.value, } result = tool.to_dict() diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 753a8f7a..5d188470 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -24,17 +24,17 @@ ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + "Billing-related error: Please ensure you have enough credits to run this asset. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", + "Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "An unspecified error occurred while processing your request.", + "Validation-related error: Please verify the request payload and ensure it is correct. Details: An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), + (501, "Unspecified Server Error (Status 501) Details: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): diff --git a/tests/unit/mock_responses/list_models_response.json b/tests/unit/mock_responses/list_models_response.json index de2cb7ba..c927ba0b 100644 --- a/tests/unit/mock_responses/list_models_response.json +++ b/tests/unit/mock_responses/list_models_response.json @@ -1,4 +1,3 @@ - { "total": 15, "pageTotal": 10, @@ -7,6 +6,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformerml6x6)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" @@ -129,6 +129,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformer12x2)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" @@ -250,6 +251,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "eBay", "name": "eBay" @@ -376,6 +378,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Canada)", "serviceName": "Amazon Translate", + "status": "onboarded", "supplier": { "id": "AWS", "name": "AWS" @@ -507,6 +510,7 @@ "id": "test_asset_id", "name": "Translate from English to French", "serviceName": "Cloud Translation", + "status": "onboarded", "supplier": { "id": "Google", "name": "Google" @@ -632,6 +636,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "ModernMT", "name": "ModernMT" @@ -758,6 +763,7 @@ "id": "test_asset_id", "name": "Translate from English to French", "serviceName": "Cognitive Services", + "status": "onboarded", "supplier": { "id": "Azure", "name": "Azure" @@ -884,6 +890,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Canada)", "serviceName": "Cognitive Services", + "status": "onboarded", "supplier": { "id": "Azure", "name": "Azure" @@ -1010,6 +1017,7 @@ { "id": "test_asset_id", "name": "Translate from English to French", + "status": "onboarded", "supplier": { "id": "AppTek", "name": "AppTek" @@ -1136,6 +1144,7 @@ "id": "test_asset_id", "name": "Translate from English to French (Transformerml24x6)", "serviceName": "Machine Translation", + "status": "onboarded", "supplier": { "id": "Nvidia", "name": "Nvidia" diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 65708d0c..943d501d 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -121,17 +121,17 @@ def test_failed_poll(): ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + "Billing-related error: Please ensure you have enough credits to run this asset. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", + "Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "An unspecified error occurred while processing your request.", + "Validation-related error: Please verify the request payload and ensure it is correct. Details: An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), + (501, "Unspecified Server Error (Status 501) Details: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): @@ -204,6 +204,7 @@ def test_get_model_from_ids(): { "id": "test-model-id-1", "name": "Test Model 1", + "status": "onboarded", "description": "Test Description 1", "function": {"id": "text-generation"}, "supplier": {"id": "aiXplain"}, @@ -214,6 +215,7 @@ def test_get_model_from_ids(): { "id": "test-model-id-2", "name": "Test Model 2", + "status": "onboarded", "description": "Test Description 2", "function": {"id": "text-generation"}, "supplier": {"id": "aiXplain"}, @@ -237,9 +239,9 @@ def test_list_models_error(): with pytest.raises(Exception) as excinfo: ModelFactory.list(model_ids=model_ids, function=Function.TEXT_GENERATION, api_key=config.AIXPLAIN_API_KEY) - assert ( - str(excinfo.value) - == "Cannot filter by function, suppliers, source languages, target languages, is finetunable, ownership, sort by when using model ids" + assert str(excinfo.value) == ( + "Cannot filter by function, suppliers, " + "source languages, target languages, is finetunable, ownership, sort by when using model ids" ) with pytest.raises(Exception) as excinfo: diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index 7df95691..5f08eec3 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -35,12 +35,8 @@ def test_create_pipeline(): headers = {"x-api-key": config.TEAM_API_KEY, "Content-Type": "application/json"} ref_response = {"id": "12345"} mock.post(url, headers=headers, json=ref_response) - ref_pipeline = Pipeline( - id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY - ) - hyp_pipeline = PipelineFactory.create( - pipeline={"nodes": []}, name="Pipeline Test" - ) + ref_pipeline = Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) + hyp_pipeline = PipelineFactory.create(pipeline={"nodes": []}, name="Pipeline Test") assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name @@ -58,19 +54,19 @@ def test_create_pipeline(): ), ( 475, - "{'error': 'Billing-related error: Please ensure you have enough credits to run this pipeline. ', 'status': 'ERROR'}", + "{'error': 'Billing-related error: Please ensure you have enough credits to run this asset.', 'status': 'ERROR'}", ), ( 485, - "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.', 'status': 'ERROR'}", + "{'error': 'Supplier-related error: Please ensure that the selected supplier provides the asset you are trying to access.', 'status': 'ERROR'}", ), ( 495, - "{'error': 'Validation-related error: Please ensure all required fields are provided and correctly formatted.', 'status': 'ERROR'}", + "{'error': 'Validation-related error: Please verify the request payload and ensure it is correct.', 'status': 'ERROR'}", ), ( 501, - "{'error': 'Status 501: Unspecified error: An unspecified error occurred while processing your request.', 'status': 'ERROR'}", + "{'error': 'Unspecified Server Error (Status 501)', 'status': 'ERROR'}", ), ], ) @@ -107,14 +103,9 @@ def test_list_pipelines_error_response(): mock.post(url, headers=headers, json=error_response, status_code=400) with pytest.raises(Exception) as excinfo: - PipelineFactory.list( - query=query, page_number=page_number, page_size=page_size - ) + PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) - assert ( - "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" - in str(excinfo.value) - ) + assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) def test_get_pipeline_error_response(): @@ -132,112 +123,7 @@ def test_get_pipeline_error_response(): with pytest.raises(Exception) as excinfo: PipelineFactory.get(pipeline_id=pipeline_id) - assert ( - "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" - in str(excinfo.value) - ) - - -@pytest.fixture -def mock_pipeline(): - return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - - -def test_run_async_success(mock_pipeline): - with requests_mock.Mocker() as mock: - execute_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" - ) - success_response = PipelineResponse( - status=ResponseStatus.SUCCESS, url=execute_url - ) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - - response = mock_pipeline.run_async(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_run_sync_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" - ) - execute_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}" - ) - success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=poll_url) - poll_response = PipelineResponse( - status=ResponseStatus.SUCCESS, data={"output": "poll_result"} - ) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - response = mock_pipeline.run(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_poll_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin( - config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}" - ) - poll_response = PipelineResponse( - status=ResponseStatus.SUCCESS, data={"output": "poll_result"} - ) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - - response = mock_pipeline.poll(poll_url=poll_url) - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - assert response.data["output"] == "poll_result" - - -@pytest.fixture -def mock_pipeline(): - return Pipeline(id="12345", name="Pipeline Test", api_key=config.TEAM_API_KEY) - - -def test_run_async_success(mock_pipeline): - with requests_mock.Mocker() as mock: - execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") - success_response = PipelineResponse(status=ResponseStatus.SUCCESS, url=execute_url) - mock.post(execute_url, json=success_response.__dict__, status_code=200) - - response = mock_pipeline.run_async(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_run_sync_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") - execute_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/run/{mock_pipeline.id}") - success_response = {"status": "SUCCESS", "url": poll_url, "completed": True} - poll_response = {"status": "SUCCESS", "data": {"output": "poll_result"}, "completed": True} - mock.post(execute_url, json=success_response, status_code=200) - mock.get(poll_url, json=poll_response, status_code=200) - response = mock_pipeline.run(data="input_data") - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - - -def test_poll_success(mock_pipeline): - with requests_mock.Mocker() as mock: - poll_url = urljoin(config.BACKEND_URL, f"assets/pipeline/execution/poll/{mock_pipeline.id}") - poll_response = PipelineResponse(status=ResponseStatus.SUCCESS, data={"output": "poll_result"}) - mock.get(poll_url, json=poll_response.__dict__, status_code=200) - - response = mock_pipeline.poll(poll_url=poll_url) - - assert isinstance(response, PipelineResponse) - assert response.status == ResponseStatus.SUCCESS - assert response.data["output"] == "poll_result" + assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value) @pytest.fixture diff --git a/tests/unit/team_agent_test.py b/tests/unit/team_agent_test.py index 2b06043e..e84e2e34 100644 --- a/tests/unit/team_agent_test.py +++ b/tests/unit/team_agent_test.py @@ -1,7 +1,7 @@ import pytest import requests_mock from urllib.parse import urljoin -from unittest.mock import patch +from unittest.mock import patch, Mock from aixplain.enums.asset_status import AssetStatus from aixplain.factories import TeamAgentFactory @@ -181,7 +181,7 @@ def test_create_team_agent(mock_model_factory_get): "role": "Test Agent Role", "teamId": "123", "version": "1.0", - "status": "draft", + "status": "onboarded", "llmId": "6646261c6eb563165658bbb1", "pricing": {"currency": "USD", "value": 0.0}, "assets": [ @@ -502,3 +502,24 @@ def get_mock(agent_id): agent1 = next((agent for agent in team_agent.agents if agent.id == "agent1"), None) assert agent1 is not None assert agent1.tasks[0].dependencies[0].name == "Test Task 2" + + +def test_deploy_team_agent(): + # Create a mock agent with ONBOARDED status + mock_agent = Mock() + mock_agent.id = "agent-id" + mock_agent.name = "Test Agent" + mock_agent.status = AssetStatus.ONBOARDED + + # Create the team agent + team_agent = TeamAgent(id="team-agent-id", name="Test Team Agent", agents=[mock_agent], status=AssetStatus.DRAFT) + + # Mock the update method + team_agent.update = Mock() + + # Deploy the team agent + team_agent.deploy() + + # Verify that status was updated and update was called + assert team_agent.status == AssetStatus.ONBOARDED + team_agent.update.assert_called_once()