diff --git a/aixplain/factories/api_key_factory.py b/aixplain/factories/api_key_factory.py index 4ac8f00a..c719c26b 100644 --- a/aixplain/factories/api_key_factory.py +++ b/aixplain/factories/api_key_factory.py @@ -4,12 +4,20 @@ from datetime import datetime from typing import Text, List, Optional, Dict, Union from aixplain.utils.file_utils import _request_with_retry -from aixplain.modules.api_key import APIKey, APIKeyGlobalLimits, APIKeyUsageLimit +from aixplain.modules.api_key import APIKey, APIKeyLimits, APIKeyUsageLimit class APIKeyFactory: backend_url = config.BACKEND_URL + @classmethod + def get(cls, api_key: Text) -> APIKey: + """Get an API key""" + for api_key_obj in cls.list(): + if str(api_key_obj.access_key).startswith(api_key[:4]) and str(api_key_obj.access_key).endswith(api_key[-4:]): + return api_key_obj + raise Exception(f"API Key Error: API key {api_key} not found") + @classmethod def list(cls) -> List[APIKey]: """List all API keys""" @@ -30,7 +38,7 @@ def list(cls) -> List[APIKey]: name=key["name"], budget=key["budget"] if "budget" in key else None, global_limits=key["globalLimits"] if "globalLimits" in key else None, - asset_limits=key["assetLimits"] if "assetLimits" in key else [], + asset_limits=key["assetsLimits"] if "assetsLimits" in key else [], expires_at=key["expiresAt"] if "expiresAt" in key else None, access_key=key["accessKey"], is_admin=key["isAdmin"], @@ -46,8 +54,8 @@ def create( cls, name: Text, budget: int, - global_limits: Union[Dict, APIKeyGlobalLimits], - asset_limits: List[Union[Dict, APIKeyGlobalLimits]], + global_limits: Union[Dict, APIKeyLimits], + asset_limits: List[Union[Dict, APIKeyLimits]], expires_at: datetime, ) -> APIKey: """Create a new API key""" @@ -84,6 +92,7 @@ def create( @classmethod def update(cls, api_key: APIKey) -> APIKey: """Update an existing API key""" + api_key.validate() try: resp = "Unspecified error" url = f"{cls.backend_url}/sdk/api-keys/{api_key.id}" @@ -112,12 +121,10 @@ def update(cls, api_key: APIKey) -> APIKey: raise Exception(f"API Key Update Error: Failed to update API key with ID {api_key.id}. 
Error: {str(resp)}") @classmethod - def get_usage_limit(cls, api_key: Text = config.TEAM_API_KEY, asset_id: Optional[Text] = None) -> APIKeyUsageLimit: - """Get API key usage limit""" + def get_usage_limits(cls, api_key: Text = config.TEAM_API_KEY, asset_id: Optional[Text] = None) -> List[APIKeyUsageLimit]: + """Get API key usage limits""" try: url = f"{config.BACKEND_URL}/sdk/api-keys/usage-limits" - if asset_id is not None: - url += f"?assetId={asset_id}" headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} logging.info(f"Start service for GET API Key Usage - {url} - {headers}") r = _request_with_retry("GET", url, headers=headers) @@ -128,11 +135,16 @@ def get_usage_limit(cls, api_key: Text = config.TEAM_API_KEY, asset_id: Optional raise Exception(f"{message}") if 200 <= r.status_code < 300: - return APIKeyUsageLimit( - request_count=resp["requestCount"], - request_count_limit=resp["requestCountLimit"], - token_count=resp["tokenCount"], - token_count_limit=resp["tokenCountLimit"], - ) + return [ + APIKeyUsageLimit( + daily_request_count=limit["requestCount"], + daily_request_limit=limit["requestCountLimit"], + daily_token_count=limit["tokenCount"], + daily_token_limit=limit["tokenCountLimit"], + model=limit["assetId"] if "assetId" in limit else None, + ) + for limit in resp + if asset_id is None or ("assetId" in limit and limit["assetId"] == asset_id) + ] else: raise Exception(f"API Key Usage Error: Failed to get usage. Error: {str(resp)}") diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index 57d4a833..305fb5d9 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -22,10 +22,8 @@ """ import logging -from typing import Dict, List, Optional, Text +from typing import Dict, List, Text import json -import pandas as pd -from pathlib import Path from aixplain.enums.supplier import Supplier from aixplain.modules import Dataset, Metric, Model from aixplain.modules.benchmark_job import BenchmarkJob @@ -34,9 +32,8 @@ from aixplain.factories.dataset_factory import DatasetFactory from aixplain.factories.model_factory import ModelFactory from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry, save_file +from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin -from warnings import warn class BenchmarkFactory: @@ -117,7 +114,7 @@ def get(cls, benchmark_id: str) -> Benchmark: logging.info(f"Start service for GET Benchmark - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - benchmark = cls._create_benchmark_from_response(resp) + except Exception as e: status_code = 400 if resp is not None and "statusCode" in resp: @@ -125,10 +122,17 @@ def get(cls, benchmark_id: str) -> Benchmark: message = resp["message"] message = f"Benchmark Creation: Status {status_code} - {message}" else: - message = f"Benchmark Creation: Unspecified Error" + message = "Benchmark Creation: Unspecified Error" logging.error(f"Benchmark Creation Failed: {e}") raise Exception(f"Status {status_code}: {message}") - return benchmark + if 200 <= r.status_code < 300: + benchmark = cls._create_benchmark_from_response(resp) + logging.info(f"Benchmark {benchmark_id} retrieved successfully.") + return benchmark + else: + error_message = f"Benchmark GET Error: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def get_job(cls, job_id: Text) -> 
BenchmarkJob: @@ -189,7 +193,7 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], """ payload = {} try: - url = urljoin(cls.backend_url, f"sdk/benchmarks") + url = urljoin(cls.backend_url, "sdk/benchmarks") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} payload = { "name": name, @@ -204,12 +208,19 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], payload = json.dumps(clean_payload) r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() - logging.info(f"Creating Benchmark Job: Status for {name}: {resp}") - return cls.get(resp["id"]) + except Exception as e: error_message = f"Creating Benchmark Job: Error in Creating Benchmark with payload {payload} : {e}" logging.error(error_message, exc_info=True) - return None + raise Exception(error_message) + + if 200 <= r.status_code < 300: + logging.info(f"Benchmark {name} created successfully.") + return cls.get(resp["id"]) + else: + error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: @@ -223,7 +234,7 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: List[str]: List of supported normalization options """ try: - url = urljoin(cls.backend_url, f"sdk/benchmarks/normalization-options") + url = urljoin(cls.backend_url, "sdk/benchmarks/normalization-options") if cls.aixplain_key != "": headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} else: @@ -231,13 +242,20 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: payload = json.dumps({"metricId": metric.id, "modelIds": [model.id]}) r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() - logging.info(f"Listing Normalization Options: Status of listing options: {resp}") - normalization_options = [item["value"] for item in resp] - return normalization_options + except Exception as e: error_message = f"Listing Normalization Options: Error in getting Normalization Options: {e}" logging.error(error_message, exc_info=True) - return [] + raise Exception(error_message) + + if 200 <= r.status_code < 300: + logging.info("Listing Normalization Options: ") + normalization_options = [item["value"] for item in resp] + return normalization_options + else: + error_message = f"Error listing normalization options: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def get_benchmark_job_scores(cls, job_id): @@ -255,7 +273,8 @@ def __get_model_name(model_id): if model.version is not None: name = f"{name}({model.version})" return name + benchmarkJob = cls.get_job(job_id) scores_df = benchmarkJob.get_scores() scores_df["Model"] = scores_df["Model"].apply(lambda x: __get_model_name(x)) - return scores_df \ No newline at end of file + return scores_df diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index 1f81ac4d..3b9c5e4b 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -21,7 +21,6 @@ Corpus Factory Class """ -import aixplain.utils.config as config import aixplain.processes.data_onboarding.onboard_functions as onboard_functions import json import logging @@ -86,12 +85,12 @@ def __from_response(cls, response: Dict) -> Corpus: try: license = 
License(response["license"]["typeId"]) - except: + except Exception: license = None try: length = int(response["segmentsCount"]) - except: + except Exception: length = None corpus = Corpus( @@ -116,17 +115,27 @@ def get(cls, corpus_id: Text) -> Corpus: Returns: Corpus: Created 'Corpus' object """ - url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + try: + url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + logging.info(f"Start service for GET Corpus - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + + except Exception as e: + error_message = f"Error retrieving Corpus {corpus_id}: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) + if 200 <= r.status_code < 300: + logging.info(f"Corpus {corpus_id} retrieved successfully.") + return cls.__from_response(resp) else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - logging.info(f"Start service for GET Corpus - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - if "statusCode" in resp and resp["statusCode"] == 404: - raise Exception(f"Corpus GET Error: Dataset {corpus_id} not found.") - return cls.__from_response(resp) + error_message = f"Corpus GET Error: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, corpus_id: Text) -> Corpus: @@ -168,7 +177,7 @@ def list( else: headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - assert 0 < page_size <= 100, f"Corpus List Error: Page size must be greater than 0 and not exceed 100." + assert 0 < page_size <= 100, "Corpus List Error: Page size must be greater than 0 and not exceed 100." 
payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]} if query is not None: @@ -188,26 +197,38 @@ def list( language = [language] payload["language"] = [lng.value["language"] for lng in language] - logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}") - r = _request_with_retry("post", url, headers=headers, json=payload) - resp = r.json() - corpora, page_total, total = [], 0, 0 - if "results" in resp: - results = resp["results"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}") - for corpus in results: - corpus_ = cls.__from_response(corpus) - # add languages - languages = [] - for lng in corpus["languages"]: - if "dialect" not in lng: - lng["dialect"] = "" - languages.append(Language(lng)) - corpus_.kwargs["languages"] = languages - corpora.append(corpus_) - return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} + try: + logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}") + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + + except Exception as e: + error_message = f"Error listing corpora: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) + + if 200 <= r.status_code < 300: + corpora, page_total, total = [], 0, 0 + if "results" in resp: + results = resp["results"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}") + for corpus in results: + corpus_ = cls.__from_response(corpus) + # add languages + languages = [] + for lng in corpus["languages"]: + if "dialect" not in lng: + lng["dialect"] = "" + languages.append(Language(lng)) + corpus_.kwargs["languages"] = languages + corpora.append(corpus_) + return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} + else: + error_message = f"Corpus List Error: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def get_assets_from_page( @@ -245,7 +266,7 @@ def create( functions: List[Function] = [], privacy: Privacy = Privacy.PRIVATE, error_handler: ErrorHandler = ErrorHandler.SKIP, - api_key: Optional[Text] = None + api_key: Optional[Text] = None, ) -> Dict: """Asynchronous call to Upload a corpus to the user's dashboard. 
diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index 5e69d572..081513c0 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -21,7 +21,6 @@ Dataset Factory Class """ -import aixplain.utils.config as config import aixplain.processes.data_onboarding.onboard_functions as onboard_functions import json import os @@ -49,7 +48,6 @@ from typing import Any, Dict, List, Optional, Text, Union from urllib.parse import urljoin from uuid import uuid4 -from warnings import warn class DatasetFactory(AssetFactory): @@ -122,7 +120,7 @@ def __from_response(cls, response: Dict) -> Dataset: target_data_list = [data[data_id] for data_id in out["dataIds"]] data_name = target_data_list[0].name target_data[data_name] = target_data_list - except: + except Exception: pass # process function @@ -164,17 +162,27 @@ def get(cls, dataset_id: Text) -> Dataset: Returns: Dataset: Created 'Dataset' object """ - url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + try: + url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + logging.info(f"Start service for GET Dataset - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + + except Exception as e: + error_message = f"Error retrieving Dataset {dataset_id}: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) + if 200 <= r.status_code < 300: + logging.info(f"Dataset {dataset_id} retrieved successfully.") + return cls.__from_response(resp) else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - logging.info(f"Start service for GET Dataset - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - if "statusCode" in resp and resp["statusCode"] == 404: - raise Exception(f"Dataset GET Error: Dataset {dataset_id} not found.") - return cls.__from_response(resp) + error_message = f"Dataset GET Error: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def list( @@ -211,7 +219,7 @@ def list( else: headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - assert 0 < page_size <= 100, f"Dataset List Error: Page size must be greater than 0 and not exceed 100." + assert 0 < page_size <= 100, "Dataset List Error: Page size must be greater than 0 and not exceed 100." 
payload = { "pageSize": page_size, "pageNumber": page_number, @@ -245,19 +253,29 @@ def list( target_languages = [target_languages] payload["output"]["languages"] = [lng.value["language"] for lng in target_languages] - logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}") - r = _request_with_retry("post", url, headers=headers, json=payload) - resp = r.json() + try: + logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}") + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() - datasets, page_total, total = [], 0, 0 - if "results" in resp: - results = resp["results"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}") - for dataset in results: - datasets.append(cls.__from_response(dataset)) - return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} + except Exception as e: + error_message = f"Error listing datasets: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) + if 200 <= r.status_code < 300: + datasets, page_total, total = [], 0, 0 + if "results" in resp: + results = resp["results"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}") + for dataset in results: + datasets.append(cls.__from_response(dataset)) + return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} + else: + error_message = f"Dataset List Error: Status {r.status_code} - {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def create( @@ -282,7 +300,7 @@ def create( error_handler: ErrorHandler = ErrorHandler.SKIP, s3_link: Optional[Text] = None, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None}, - api_key: Optional[Text] = None + api_key: Optional[Text] = None, ) -> Dict: """Dataset Onboard diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index d82bdd63..209ff75d 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -80,6 +80,7 @@ def _create_model_from_response(cls, response: Dict) -> Model: return ModelClass( response["id"], response["name"], + description=response.get("description", ""), supplier=response["supplier"], api_key=response["api_key"], cost=response["pricing"], @@ -113,13 +114,7 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: logging.info(f"Start service for GET Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - # set api key - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - model = cls._create_model_from_response(resp) - logging.info(f"Model Creation: Model {model_id} instantiated.") - return model + except Exception: if resp is not None and "statusCode" in resp: status_code = resp["statusCode"] @@ -129,6 +124,17 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: message = "Model Creation: Unspecified Error" logging.error(message) raise Exception(f"{message}") + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + model = cls._create_model_from_response(resp) + logging.info(f"Model Creation: Model 
{model_id} instantiated.") + return model + else: + error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, model_id: Text) -> Model: @@ -198,20 +204,26 @@ def _get_assets_from_page( logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") r = _request_with_retry("post", url, headers=headers, json=filter_params) resp = r.json() - logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") - all_models = resp["items"] - model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] - return model_list, resp["total"] + except Exception as e: error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" logging.error(error_message, exc_info=True) return [] + if 200 <= r.status_code < 300: + logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") + all_models = resp["items"] + model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] + return model_list, resp["total"] + else: + error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def list( cls, + function: Function, query: Optional[Text] = "", - function: Optional[Function] = None, suppliers: Optional[Union[Supplier, List[Supplier]]] = None, source_languages: Optional[Union[Language, List[Language]]] = None, target_languages: Optional[Union[Language, List[Language]]] = None, @@ -225,7 +237,7 @@ def list( """Gets the first k given models based on the provided task and language filters Args: - function (Optional[Function], optional): function filter. Defaults to None. + function (Function): function filter. source_languages (Optional[Union[Language, List[Language]]], optional): language filter of input data. Defaults to None. target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None. is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None. 
@@ -237,30 +249,25 @@ def list( Returns: List[Model]: List of models based on given filters """ - try: - models, total = cls._get_assets_from_page( - query, - page_number, - page_size, - function, - suppliers, - source_languages, - target_languages, - is_finetunable, - ownership, - sort_by, - sort_order, - ) - return { - "results": models, - "page_total": min(page_size, len(models)), - "page_number": page_number, - "total": total, - } - except Exception as e: - error_message = f"Listing Models: Error in Listing Models : {e}" - logging.error(error_message, exc_info=True) - raise Exception(error_message) + models, total = cls._get_assets_from_page( + query, + page_number, + page_size, + function, + suppliers, + source_languages, + target_languages, + is_finetunable, + ownership, + sort_by, + sort_order, + ) + return { + "results": models, + "page_total": min(page_size, len(models)), + "page_number": page_number, + "total": total, + } @classmethod def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]: diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index cb4336fe..ef330de0 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -78,12 +78,7 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: logging.info(f"Start service for GET Pipeline - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - # set api key - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - pipeline = build_from_response(resp, load_architecture=True) - return pipeline + except Exception as e: logging.exception(e) status_code = 400 @@ -95,6 +90,20 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: message = f"Pipeline Creation: Unspecified Error {e}" logging.error(message) raise Exception(f"Status {status_code}: {message}") + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + pipeline = build_from_response(resp, load_architecture=True) + logging.info(f"Pipeline {pipeline_id} retrieved successfully.") + return pipeline + + else: + error_message = ( + f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. 
Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, pipeline_id: Text) -> Pipeline: @@ -220,23 +229,33 @@ def list( payload["inputDataTypes"] = [data_type.value for data_type in output_data_types] logging.info(f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}") - r = _request_with_retry("post", url, headers=headers, json=payload) - resp = r.json() - - pipelines, page_total, total = [], 0, 0 - if "items" in resp: - results = resp["items"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") - for pipeline in results: - pipelines.append(build_from_response(pipeline)) - return { - "results": pipelines, - "page_total": page_total, - "page_number": page_number, - "total": total, - } + try: + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + + except Exception as e: + error_message = f"Pipeline List Error: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) + if 200 <= r.status_code < 300: + pipelines, page_total, total = [], 0, 0 + if "items" in resp: + results = resp["items"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") + for pipeline in results: + pipelines.append(build_from_response(pipeline)) + return { + "results": pipelines, + "page_total": page_total, + "page_number": page_number, + "total": total, + } + else: + error_message = f"Pipeline List Error: Failed to retrieve pipelines. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def init(cls, name: Text, api_key: Optional[Text] = None) -> Pipeline: diff --git a/aixplain/factories/pipeline_factory/utils.py b/aixplain/factories/pipeline_factory/utils.py index 9584863f..7911c370 100644 --- a/aixplain/factories/pipeline_factory/utils.py +++ b/aixplain/factories/pipeline_factory/utils.py @@ -86,8 +86,23 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe node.label = node_json["label"] pipeline.add_node(node) + # Decision nodes' output parameters are defined based on their + # input parameters linked. 
So here we have to make sure that + # decision nodes (having passthrough parameter) should be first + # linked + link_jsons = response["links"][:] + decision_links = [] + for link_json in link_jsons: + for pm in link_json["paramMapping"]: + if pm["to"] == "passthrough": + decision_link_index = link_jsons.index(link_json) + decision_link = link_jsons.pop(decision_link_index) + decision_links.append(decision_link) + + link_jsons = decision_links + link_jsons + # instantiating links - for link_json in response["links"]: + for link_json in link_jsons: for param_mapping in link_json["paramMapping"]: link = Link( from_node=pipeline.get_node(link_json["from"]), diff --git a/aixplain/modules/__init__.py b/aixplain/modules/__init__.py index d49e29d4..4432e1ad 100644 --- a/aixplain/modules/__init__.py +++ b/aixplain/modules/__init__.py @@ -36,4 +36,4 @@ from .agent import Agent from .agent.tool import Tool from .team_agent import TeamAgent -from .api_key import APIKey, APIKeyGlobalLimits, APIKeyUsageLimit +from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index 546ea4d8..41bb0a2e 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -105,6 +105,8 @@ def run( parameters: Dict = {}, wait_time: float = 0.5, content: Optional[Union[Dict[Text, Text], List[Text]]] = None, + max_tokens: int = 2048, + max_iterations: int = 10, ) -> Dict: """Runs an agent call. @@ -118,6 +120,8 @@ def run( parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. + max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. + max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. Returns: Dict: parsed output from model @@ -132,6 +136,8 @@ def run( name=name, parameters=parameters, content=content, + max_tokens=max_tokens, + max_iterations=max_iterations, ) if response["status"] == "FAILED": end = time.time() @@ -156,6 +162,8 @@ def run_async( name: Text = "model_process", parameters: Dict = {}, content: Optional[Union[Dict[Text, Text], List[Text]]] = None, + max_tokens: int = 2048, + max_iterations: int = 10, ) -> Dict: """Runs asynchronously an agent call. @@ -167,6 +175,8 @@ def run_async( name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. + max_tokens (int, optional): maximum number of tokens which can be generated by the agent. Defaults to 2048. + max_iterations (int, optional): maximum number of iterations between the agent and the tools. Defaults to 10. 
Returns: dict: polling URL in response @@ -205,6 +215,12 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history} + parameters.update( + { + "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, + } + ) payload.update(parameters) payload = json.dumps(payload) @@ -236,6 +252,10 @@ def delete(self) -> None: if r.status_code != 200: raise Exception() except Exception: - message = f"Agent Deletion Error (HTTP {r.status_code}): Make sure the agent exists and you are the owner." + try: + response_json = r.json() + message = f"Agent Deletion Error (HTTP {r.status_code}): {response_json.get('message')}." + except ValueError: + message = f"Agent Deletion Error (HTTP {r.status_code}): There was an error in deleting the agent." logging.error(message) - raise Exception(f"{message}") + raise Exception(message) \ No newline at end of file diff --git a/aixplain/modules/api_key.py b/aixplain/modules/api_key.py index 886b0dab..ae774c23 100644 --- a/aixplain/modules/api_key.py +++ b/aixplain/modules/api_key.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Text, Union -class APIKeyGlobalLimits: +class APIKeyLimits: def __init__( self, token_per_minute: int, @@ -27,19 +27,31 @@ def __init__( class APIKeyUsageLimit: - def __init__(self, request_count: int, request_count_limit: int, token_count: int, token_count_limit: int): - """Get the usage limits of an API key + def __init__( + self, + daily_request_count: int, + daily_request_limit: int, + daily_token_count: int, + daily_token_limit: int, + model: Optional[Union[Text, Model]] = None, + ): + """Get the usage limits of an API key globally (model equals to None) or for a specific model. Args: - request_count (int): number of requests made - request_count_limit (int): limit of requests - token_count (int): number of tokens used - token_count_limit (int): limit of tokens + daily_request_count (int): number of requests made + daily_request_limit (int): limit of requests + daily_token_count (int): number of tokens used + daily_token_limit (int): limit of tokens + model (Optional[Union[Text, Model]], optional): Model which the limits apply. Defaults to None. 
""" - self.request_count = request_count - self.request_count_limit = request_count_limit - self.token_count = token_count - self.token_count_limit = token_count_limit + self.daily_request_count = daily_request_count + self.daily_request_limit = daily_request_limit + self.daily_token_count = daily_token_count + self.daily_token_limit = daily_token_limit + if model is not None and isinstance(model, str): + from aixplain.factories import ModelFactory + + self.model = ModelFactory.get(model) class APIKey: @@ -48,8 +60,8 @@ def __init__( name: Text, expires_at: Optional[Union[datetime, Text]] = None, budget: Optional[float] = None, - asset_limits: List[APIKeyGlobalLimits] = [], - global_limits: Optional[Union[Dict, APIKeyGlobalLimits]] = None, + asset_limits: List[APIKeyLimits] = [], + global_limits: Optional[Union[Dict, APIKeyLimits]] = None, id: int = "", access_key: Optional[Text] = None, is_admin: bool = False, @@ -59,7 +71,7 @@ def __init__( self.budget = budget self.global_limits = global_limits if global_limits is not None and isinstance(global_limits, dict): - self.global_limits = APIKeyGlobalLimits( + self.global_limits = APIKeyLimits( token_per_minute=global_limits["tpm"], token_per_day=global_limits["tpd"], request_per_minute=global_limits["rpm"], @@ -68,7 +80,7 @@ def __init__( self.asset_limits = asset_limits for i, asset_limit in enumerate(self.asset_limits): if isinstance(asset_limit, dict): - self.asset_limits[i] = APIKeyGlobalLimits( + self.asset_limits[i] = APIKeyLimits( token_per_minute=asset_limit["tpm"], token_per_day=asset_limit["tpd"], request_per_minute=asset_limit["rpm"], @@ -110,7 +122,7 @@ def to_dict(self) -> Dict: "id": self.id, "name": self.name, "budget": self.budget, - "assetLimits": [], + "assetsLimits": [], "expiresAt": self.expires_at, } @@ -126,7 +138,7 @@ def to_dict(self) -> Dict: } for i, asset_limit in enumerate(self.asset_limits): - payload["assetLimits"].append( + payload["assetsLimits"].append( { "tpm": asset_limit.token_per_minute, "tpd": asset_limit.token_per_day, @@ -157,8 +169,6 @@ def get_usage(self, asset_id: Optional[Text] = None) -> APIKeyUsageLimit: url = f"{config.BACKEND_URL}/sdk/api-keys/{self.id}/usage-limits" headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET API Key Usage - {url} - {headers}") - if asset_id is not None: - url += f"?assetId={asset_id}" r = _request_with_retry("GET", url, headers=headers) resp = r.json() except Exception: @@ -167,11 +177,48 @@ def get_usage(self, asset_id: Optional[Text] = None) -> APIKeyUsageLimit: raise Exception(f"{message}") if 200 <= r.status_code < 300: - return APIKeyUsageLimit( - request_count=resp["requestCount"], - request_count_limit=resp["requestCountLimit"], - token_count=resp["tokenCount"], - token_count_limit=resp["tokenCountLimit"], - ) + return [ + APIKeyUsageLimit( + daily_request_count=limit["requestCount"], + daily_request_limit=limit["requestCountLimit"], + daily_token_count=limit["tokenCount"], + daily_token_limit=limit["tokenCountLimit"], + model=limit["assetId"] if "assetId" in limit else None, + ) + for limit in resp + if asset_id is None or ("assetId" in limit and limit["assetId"] == asset_id) + ] else: raise Exception(f"API Key Usage Error: Failed to get usage. 
Error: {str(resp)}") + + def __set_limit(self, limit: int, model: Optional[Union[Text, Model]], limit_type: Text) -> None: + """Set a limit for an API key""" + if model is None: + setattr(self.global_limits, limit_type, limit) + else: + if isinstance(model, Model): + model = model.id + is_found = False + for i, asset_limit in enumerate(self.asset_limits): + if asset_limit.model.id == model: + setattr(self.asset_limits[i], limit_type, limit) + is_found = True + break + if is_found is False: + raise Exception(f"Limit for Model {model} not found in the API key.") + + def set_token_per_day(self, token_per_day: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the token per day limit of an API key""" + self.__set_limit(token_per_day, model, "token_per_day") + + def set_token_per_minute(self, token_per_minute: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the token per minute limit of an API key""" + self.__set_limit(token_per_minute, model, "token_per_minute") + + def set_request_per_day(self, request_per_day: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the request per day limit of an API key""" + self.__set_limit(request_per_day, model, "request_per_day") + + def set_request_per_minute(self, request_per_minute: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the request per minute limit of an API key""" + self.__set_limit(request_per_minute, model, "request_per_minute") diff --git a/aixplain/modules/model/__init__.py b/aixplain/modules/model/__init__.py index 2e9445b5..765960d4 100644 --- a/aixplain/modules/model/__init__.py +++ b/aixplain/modules/model/__init__.py @@ -21,11 +21,11 @@ Model Class """ import time -import json import logging import traceback from aixplain.enums import Supplier, Function from aixplain.modules.asset import Asset +from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config from urllib.parse import urljoin from aixplain.utils.file_utils import _request_with_retry @@ -149,7 +149,7 @@ def sync_poll(self, poll_url: Text, name: Text = "model_process", wait_time: flo logging.error(f"Polling for Model: polling for {name}: {e}") break if response_body["completed"] is True: - logging.info(f"Polling for Model: Final status of polling for {name}: {response_body}") + logging.debug(f"Polling for Model: Final status of polling for {name}: {response_body}") else: response_body["status"] = "FAILED" logging.error( @@ -188,7 +188,7 @@ def run( data: Union[Text, Dict], name: Text = "model_process", timeout: float = 300, - parameters: Dict = {}, + parameters: Optional[Dict] = {}, wait_time: float = 0.5, ) -> Dict: """Runs a model call. 
@@ -204,23 +204,23 @@ def run( Dict: parsed output from model """ start = time.time() - try: - response = self.run_async(data, name=name, parameters=parameters) - if response["status"] == "FAILED": + payload = build_payload(data=data, parameters=parameters) + url = f"{self.url}/{self.id}".replace("api/v1/execute", "api/v2/execute") + logging.debug(f"Model Run Sync: Start service for {name} - {url}") + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) + if response["status"] == "IN_PROGRESS": + try: + poll_url = response["url"] end = time.time() - response["elapsed_time"] = end - start - return response - poll_url = response["url"] - end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - return response - except Exception as e: - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"Model Run: Error in running for {name}: {e}") - end = time.time() - return {"status": "FAILED", "error": msg, "elapsed_time": end - start} + response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + except Exception as e: + msg = f"Error in request for {name} - {traceback.format_exc()}" + logging.error(f"Model Run: Error in running for {name}: {e}") + end = time.time() + response = {"status": "FAILED", "error": msg, "elapsed_time": end - start} + return response - def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Dict = {}) -> Dict: + def run_async(self, data: Union[Text, Dict], name: Text = "model_process", parameters: Optional[Dict] = {}) -> Dict: """Runs asynchronously a model call. Args: @@ -231,59 +231,10 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param Returns: dict: polling URL in response """ - headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} - from aixplain.factories.file_factory import FileFactory - - data = FileFactory.to_link(data) - if isinstance(data, dict): - payload = data - else: - try: - payload = json.loads(data) - if isinstance(payload, dict) is False: - if isinstance(payload, int) is True or isinstance(payload, float) is True: - payload = str(payload) - payload = {"data": payload} - except Exception: - payload = {"data": data} - payload.update(parameters) - payload = json.dumps(payload) - - call_url = f"{self.url}/{self.id}" - r = _request_with_retry("post", call_url, headers=headers, data=payload) - logging.info(f"Model Run Async: Start service for {name} - {self.url} - {payload} - {headers}") - - resp = None - try: - if 200 <= r.status_code < 300: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} - else: - if r.status_code == 401: - error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." - elif 460 <= r.status_code < 470: - error = "Subscription-related error: Please ensure that your subscription is active and has not expired." - elif 470 <= r.status_code < 480: - error = "Billing-related error: Please ensure you have enough credits to run this model. " - elif 480 <= r.status_code < 490: - error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access." - elif 490 <= r.status_code < 500: - error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." 
- else: - status_code = str(r.status_code) - error = ( - f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." - ) - response = {"status": "FAILED", "error_message": error} - logging.error(f"Error in request for {name} - {r.status_code}: {error}") - except Exception: - response = {"status": "FAILED"} - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"Model Run Async: Error in running for {name}: {resp}") - if resp is not None: - response["error"] = msg + url = f"{self.url}/{self.id}" + logging.debug(f"Model Run Async: Start service for {name} - {url}") + payload = build_payload(data=data, parameters=parameters) + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return response def check_finetune_status(self, after_epoch: Optional[int] = None): diff --git a/aixplain/modules/model/llm_model.py b/aixplain/modules/model/llm_model.py index c595d207..f48a3068 100644 --- a/aixplain/modules/model/llm_model.py +++ b/aixplain/modules/model/llm_model.py @@ -21,13 +21,12 @@ Large Language Model Class """ import time -import json import logging import traceback from aixplain.enums import Function, Supplier from aixplain.modules.model import Model +from aixplain.modules.model.utils import build_payload, call_run_endpoint from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry from typing import Union, Optional, List, Text, Dict @@ -103,7 +102,7 @@ def run( top_p: float = 1.0, name: Text = "model_process", timeout: float = 300, - parameters: Dict = {}, + parameters: Optional[Dict] = {}, wait_time: float = 0.5, ) -> Dict: """Synchronously running a Large Language Model (LLM) model. @@ -125,31 +124,31 @@ def run( Dict: parsed output from model """ start = time.time() - try: - response = self.run_async( - data, - name=name, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - context=context, - prompt=prompt, - history=history, - parameters=parameters, - ) - if response["status"] == "FAILED": + parameters.update( + { + "context": parameters["context"] if "context" in parameters else context, + "prompt": parameters["prompt"] if "prompt" in parameters else prompt, + "history": parameters["history"] if "history" in parameters else history, + "temperature": parameters["temperature"] if "temperature" in parameters else temperature, + "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "top_p": parameters["top_p"] if "top_p" in parameters else top_p, + } + ) + payload = build_payload(data=data, parameters=parameters) + url = f"{self.url}/{self.id}".replace("/api/v1/execute", "/api/v2/execute") + logging.debug(f"Model Run Sync: Start service for {name} - {url}") + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) + if response["status"] == "IN_PROGRESS": + try: + poll_url = response["url"] end = time.time() - response["elapsed_time"] = end - start - return response - poll_url = response["url"] - end = time.time() - response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) - return response - except Exception as e: - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"LLM Run: Error in running for {name}: {e}") - end = time.time() - return {"status": "FAILED", "error": msg, "elapsed_time": end - start} + response = self.sync_poll(poll_url, name=name, timeout=timeout, wait_time=wait_time) + except Exception as e: + msg = f"Error in 
request for {name} - {traceback.format_exc()}" + logging.error(f"Model Run: Error in running for {name}: {e}") + end = time.time() + response = {"status": "FAILED", "error": msg, "elapsed_time": end - start} + return response def run_async( self, @@ -161,7 +160,7 @@ def run_async( max_tokens: int = 128, top_p: float = 1.0, name: Text = "model_process", - parameters: Dict = {}, + parameters: Optional[Dict] = {}, ) -> Dict: """Runs asynchronously a model call. @@ -179,66 +178,18 @@ def run_async( Returns: dict: polling URL in response """ - headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} - - from aixplain.factories.file_factory import FileFactory - - data = FileFactory.to_link(data) - if isinstance(data, dict): - payload = data - else: - try: - payload = json.loads(data) - if isinstance(payload, dict) is False: - if isinstance(payload, int) is True or isinstance(payload, float) is True: - payload = str(payload) - payload = {"data": payload} - except Exception: - payload = {"data": data} + url = f"{self.url}/{self.id}" + logging.debug(f"Model Run Async: Start service for {name} - {url}") parameters.update( { - "context": payload["context"] if "context" in payload else context, - "prompt": payload["prompt"] if "prompt" in payload else prompt, - "history": payload["history"] if "history" in payload else history, - "temperature": payload["temperature"] if "temperature" in payload else temperature, - "max_tokens": payload["max_tokens"] if "max_tokens" in payload else max_tokens, - "top_p": payload["top_p"] if "top_p" in payload else top_p, + "context": parameters["context"] if "context" in parameters else context, + "prompt": parameters["prompt"] if "prompt" in parameters else prompt, + "history": parameters["history"] if "history" in parameters else history, + "temperature": parameters["temperature"] if "temperature" in parameters else temperature, + "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "top_p": parameters["top_p"] if "top_p" in parameters else top_p, } ) - payload.update(parameters) - payload = json.dumps(payload) - - call_url = f"{self.url}/{self.id}" - r = _request_with_retry("post", call_url, headers=headers, data=payload) - logging.info(f"Model Run Async: Start service for {name} - {self.url} - {payload} - {headers}") - - resp = None - try: - if 200 <= r.status_code < 300: - resp = r.json() - logging.info(f"Result of request for {name} - {r.status_code} - {resp}") - poll_url = resp["data"] - response = {"status": "IN_PROGRESS", "url": poll_url} - else: - if r.status_code == 401: - error = "Unauthorized API key: Please verify the spelling of the API key and its current validity." - elif 460 <= r.status_code < 470: - error = "Subscription-related error: Please ensure that your subscription is active and has not expired." - elif 470 <= r.status_code < 480: - error = "Billing-related error: Please ensure you have enough credits to run this model. " - elif 480 <= r.status_code < 490: - error = "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access." - elif 490 <= r.status_code < 500: - error = "Validation-related error: Please ensure all required fields are provided and correctly formatted." - else: - status_code = str(r.status_code) - error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request." 
- response = {"status": "FAILED", "error_message": error} - logging.error(f"Error in request for {name} - {r.status_code}: {error}") - except Exception: - response = {"status": "FAILED"} - msg = f"Error in request for {name} - {traceback.format_exc()}" - logging.error(f"Model Run Async: Error in running for {name}: {resp}") - if resp is not None: - response["error"] = msg + payload = build_payload(data=data, parameters=parameters) + response = call_run_endpoint(payload=payload, url=url, api_key=self.api_key) return response diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py new file mode 100644 index 00000000..2235b35a --- /dev/null +++ b/aixplain/modules/model/utils.py @@ -0,0 +1,76 @@ +__author__ = "thiagocastroferreira" + +import json +import logging +from aixplain.utils.file_utils import _request_with_retry +from typing import Dict, Text, Union + + +def build_payload(data: Union[Text, Dict], parameters: Dict = {}): + from aixplain.factories import FileFactory + + data = FileFactory.to_link(data) + if isinstance(data, dict): + payload = data + else: + try: + payload = json.loads(data) + if isinstance(payload, dict) is False: + if isinstance(payload, int) is True or isinstance(payload, float) is True: + payload = str(payload) + payload = {"data": payload} + except Exception: + payload = {"data": data} + payload.update(parameters) + payload = json.dumps(payload) + return payload + + +def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: + headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + resp = "unspecified error" + try: + r = _request_with_retry("post", url, headers=headers, data=payload) + resp = r.json() + except Exception as e: + logging.error(f"Error in request: {e}") + response = { + "status": "FAILED", + "completed": True, + "error_message": "Model Run: An error occurred while processing your request.", + } + + if 200 <= r.status_code < 300: + logging.info(f"Result of request: {r.status_code} - {resp}") + status = resp.get("status", "IN_PROGRESS") + data = resp.get("data", None) + if status == "IN_PROGRESS": + if data is not None: + response = {"status": status, "url": data, "completed": True} + else: + response = { + "status": "FAILED", + "completed": True, + "error_message": "Model Run: An error occurred while processing your request.", + } + else: + response = resp + else: + resp = resp["error"] if isinstance(resp, dict) and "error" in resp else resp + if r.status_code == 401: + error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}" + elif 460 <= r.status_code < 470: + error = f"Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {resp}" + elif 470 <= r.status_code < 480: + error = f"Billing-related error: Please ensure you have enough credits to run this model. Details: {resp}" + elif 480 <= r.status_code < 490: + error = f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. 
Details: {resp}" + elif 490 <= r.status_code < 500: + error = f"{resp}" + else: + status_code = str(r.status_code) + error = f"Status {status_code} - Unspecified error: {resp}" + response = {"status": "FAILED", "error_message": error, "completed": True} + logging.error(f"Error in request: {r.status_code}: {error}") + return response diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index 76e6196d..49c68463 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -142,14 +142,34 @@ def __init__( pipeline: "DesignerPipeline" = None, ): - assert from_param in from_node.outputs, "Invalid from param" - assert to_param in to_node.inputs, "Invalid to param" - if isinstance(from_param, Param): from_param = from_param.code if isinstance(to_param, Param): to_param = to_param.code + assert from_param in from_node.outputs, ( + "Invalid from param. " + "Make sure all input params are already linked accordingly" + ) + + fp_instance = from_node.outputs[from_param] + from .nodes import Decision + + if ( + isinstance(to_node, Decision) + and to_param == to_node.inputs.passthrough.code + ): + if from_param not in to_node.outputs: + to_node.outputs.create_param( + from_param, + fp_instance.data_type, + is_required=fp_instance.is_required, + ) + else: + to_node.outputs[from_param].data_type = fp_instance.data_type + + assert to_param in to_node.inputs, "Invalid to param" + self.from_node = from_node self.to_node = to_node self.from_param = from_param @@ -233,9 +253,7 @@ def __init__(self, node: "Node", *args, **kwargs): def add_param(self, param: Param) -> None: # check if param already registered if param in self: - raise ValueError( - f"Parameter with code '{param.code}' already exists." 
- ) + raise ValueError(f"Parameter with code '{param.code}' already exists.") self._params.append(param) # also set attribute on the node dynamically if there's no # any attribute with the same name @@ -353,9 +371,7 @@ def attach_to(self, pipeline: "DesignerPipeline"): :param pipeline: the pipeline """ assert not self.pipeline, "Node already attached to a pipeline" - assert ( - self not in pipeline.nodes - ), "Node already attached to a pipeline" + assert self not in pipeline.nodes, "Node already attached to a pipeline" assert self.type, "Node type not set" self.pipeline = pipeline diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index a6879e04..70ff302f 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -288,11 +288,15 @@ def __init__(self, value: DataType, path: List[Union[Node, int]], operation: Ope self.operation = operation self.type = type - if not self.path: - raise ValueError("Path is not valid, should be a list of nodes") + # Path can be an empty list in case the user has a valid case + # if not self.path: + # raise ValueError("Path is not valid, should be a list of nodes") # convert nodes to node numbers if they are nodes - self.path = [node.number if isinstance(node, Node) else node for node in self.path] + self.path = [ + node.number if isinstance(node, Node) else node + for node in self.path + ] def serialize(self) -> dict: return { diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 420fc23a..86321489 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -108,6 +108,8 @@ def run( parameters: Dict = {}, wait_time: float = 0.5, content: Optional[Union[Dict[Text, Text], List[Text]]] = None, + max_tokens: int = 2048, + max_iterations: int = 30, ) -> Dict: """Runs a team agent call. @@ -121,7 +123,8 @@ def run( parameters (Dict, optional): optional parameters to the model. Defaults to "{}". wait_time (float, optional): wait time in seconds between polling calls. Defaults to 0.5. content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. - + max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. + max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. Returns: Dict: parsed output from model """ @@ -135,6 +138,8 @@ def run( name=name, parameters=parameters, content=content, + max_tokens=max_tokens, + max_iterations=max_iterations, ) if response["status"] == "FAILED": end = time.time() @@ -159,6 +164,8 @@ def run_async( name: Text = "model_process", parameters: Dict = {}, content: Optional[Union[Dict[Text, Text], List[Text]]] = None, + max_tokens: int = 2048, + max_iterations: int = 30, ) -> Dict: """Runs asynchronously a Team Agent call. @@ -170,6 +177,8 @@ def run_async( name (Text, optional): ID given to a call. Defaults to "model_process". parameters (Dict, optional): optional parameters to the model. Defaults to "{}". content (Union[Dict[Text, Text], List[Text]], optional): Content inputs to be processed according to the query. Defaults to None. + max_tokens (int, optional): maximum number of tokens which can be generated by the agents. Defaults to 2048. + max_iterations (int, optional): maximum number of iterations between the agents. Defaults to 30. 
Returns: dict: polling URL in response @@ -208,6 +217,12 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} payload = {"id": self.id, "query": FileFactory.to_link(query), "sessionId": session_id, "history": history} + parameters.update( + { + "max_tokens": parameters["max_tokens"] if "max_tokens" in parameters else max_tokens, + "max_iterations": parameters["max_iterations"] if "max_iterations" in parameters else max_iterations, + } + ) payload.update(parameters) payload = json.dumps(payload) diff --git a/aixplain/utils/config.py b/aixplain/utils/config.py index 3bb0eb09..59805c60 100644 --- a/aixplain/utils/config.py +++ b/aixplain/utils/config.py @@ -19,11 +19,11 @@ logger = logging.getLogger(__name__) BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com") -MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute") +MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com") # GET THE API KEY FROM CMD TEAM_API_KEY = os.getenv("TEAM_API_KEY", "") AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "") PIPELINE_API_KEY = os.getenv("PIPELINE_API_KEY", "") MODEL_API_KEY = os.getenv("MODEL_API_KEY", "") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") -HF_TOKEN = os.getenv("HF_TOKEN", "") \ No newline at end of file +HF_TOKEN = os.getenv("HF_TOKEN", "") diff --git a/pyproject.toml b/pyproject.toml index e0df02a2..1656947a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ namespaces = true [project] name = "aiXplain" -version = "0.2.21rc0" +version = "0.2.21" description = "aiXplain SDK adds AI functions to software." readme = "README.md" requires-python = ">=3.5, <4" diff --git a/tests/functional/apikey/test_api.py b/tests/functional/apikey/test_api.py index 80b75189..221a58fb 100644 --- a/tests/functional/apikey/test_api.py +++ b/tests/functional/apikey/test_api.py @@ -1,5 +1,5 @@ from aixplain.factories.api_key_factory import APIKeyFactory -from aixplain.modules import APIKey, APIKeyGlobalLimits, APIKeyUsageLimit +from aixplain.modules import APIKey, APIKeyLimits, APIKeyUsageLimit from datetime import datetime import json import pytest @@ -16,7 +16,7 @@ def test_create_api_key_from_json(): api_key = APIKeyFactory.create( name=api_key_data["name"], asset_limits=[ - APIKeyGlobalLimits( + APIKeyLimits( model=api_key_data["asset_limits"][0]["model"], token_per_minute=api_key_data["asset_limits"][0]["token_per_minute"], token_per_day=api_key_data["asset_limits"][0]["token_per_day"], @@ -24,7 +24,7 @@ def test_create_api_key_from_json(): request_per_minute=api_key_data["asset_limits"][0]["request_per_minute"], ) ], - global_limits=APIKeyGlobalLimits( + global_limits=APIKeyLimits( token_per_minute=api_key_data["global_limits"]["token_per_minute"], token_per_day=api_key_data["global_limits"]["token_per_day"], request_per_day=api_key_data["global_limits"]["request_per_day"], @@ -60,8 +60,8 @@ def test_create_api_key_from_dict(): api_key_name = "Test API Key" api_key = APIKeyFactory.create( name=api_key_name, - asset_limits=[APIKeyGlobalLimits(**limit) for limit in api_key_dict["asset_limits"]], - global_limits=APIKeyGlobalLimits(**api_key_dict["global_limits"]), + asset_limits=[APIKeyLimits(**limit) for limit in api_key_dict["asset_limits"]], + global_limits=APIKeyLimits(**api_key_dict["global_limits"]), budget=api_key_dict["budget"], expires_at=datetime.strptime(api_key_dict["expires_at"], "%Y-%m-%dT%H:%M:%SZ"), ) @@ -92,8 +92,8 @@ def test_create_update_api_key_from_dict(): 
api_key_name = "Test API Key" api_key = APIKeyFactory.create( name=api_key_name, - asset_limits=[APIKeyGlobalLimits(**limit) for limit in api_key_dict["asset_limits"]], - global_limits=APIKeyGlobalLimits(**api_key_dict["global_limits"]), + asset_limits=[APIKeyLimits(**limit) for limit in api_key_dict["asset_limits"]], + global_limits=APIKeyLimits(**api_key_dict["global_limits"]), budget=api_key_dict["budget"], expires_at=datetime.strptime(api_key_dict["expires_at"], "%Y-%m-%dT%H:%M:%SZ"), ) @@ -102,6 +102,11 @@ def test_create_update_api_key_from_dict(): assert api_key.id != "" assert api_key.name == api_key_name + api_key_ = APIKeyFactory.get(api_key=api_key.access_key) + assert isinstance(api_key_, APIKey) + assert api_key_.id != "" + assert api_key_.name == api_key_name + api_key.global_limits.token_per_day = 222 api_key.global_limits.token_per_minute = 222 api_key.global_limits.request_per_day = 222 @@ -134,7 +139,65 @@ def test_list_api_keys(): if api_key.is_admin is False: usage = api_key.get_usage() - assert isinstance(usage, APIKeyUsageLimit) + assert isinstance(usage, list) + if len(usage) > 0: + assert isinstance(usage[0], APIKeyUsageLimit) + + +def test_list_update_api_keys(): + api_keys = APIKeyFactory.list() + assert isinstance(api_keys, list) + + for api_key in api_keys: + assert isinstance(api_key, APIKey) + assert api_key.id != "" + + from random import randint + + number = randint(0, 10000) + if api_key.global_limits is None: + api_key.global_limits = APIKeyLimits( + token_per_minute=number, + token_per_day=number, + request_per_day=number, + request_per_minute=number, + ) + else: + api_key.global_limits.token_per_day = number + api_key.global_limits.token_per_minute = number + api_key.global_limits.request_per_day = number + api_key.global_limits.request_per_minute = number + + if api_key.asset_limits is None: + api_key.asset_limits = [] + + if len(api_key.asset_limits) == 0: + api_key.asset_limits.append( + APIKeyLimits( + model="640b517694bf816d35a59125", + token_per_minute=number, + token_per_day=number, + request_per_day=number, + request_per_minute=number, + ) + ) + else: + api_key.asset_limits[0].request_per_day = number + api_key.asset_limits[0].request_per_minute = number + api_key.asset_limits[0].token_per_day = number + api_key.asset_limits[0].token_per_minute = number + api_key = APIKeyFactory.update(api_key) + + assert api_key.global_limits.token_per_day == number + assert api_key.global_limits.token_per_minute == number + assert api_key.global_limits.request_per_day == number + assert api_key.global_limits.request_per_minute == number + assert api_key.asset_limits[0].request_per_day == number + assert api_key.asset_limits[0].request_per_minute == number + assert api_key.asset_limits[0].token_per_day == number + assert api_key.asset_limits[0].token_per_minute == number + break + def test_list_update_api_keys(): diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index b0d8f6ef..266b04ea 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -33,7 +33,10 @@ def __get_asset_factory(asset_name): @pytest.mark.parametrize("asset_name", ["model", "dataset", "metric"]) def test_list(asset_name): AssetFactory = __get_asset_factory(asset_name) - asset_list = AssetFactory.list() + if asset_name == "model": + asset_list = AssetFactory.list(function=Function.TRANSLATION) + else: + asset_list = AssetFactory.list() 
assert asset_list["page_total"] == len(asset_list["results"]) @@ -62,7 +65,7 @@ def test_model_function(): def test_model_supplier(): desired_suppliers = [Supplier.GOOGLE] - models = ModelFactory.list(suppliers=desired_suppliers)["results"] + models = ModelFactory.list(suppliers=desired_suppliers, function=Function.TRANSLATION)["results"] for model in models: assert model.supplier.value in [desired_supplier.value for desired_supplier in desired_suppliers] @@ -89,14 +92,14 @@ def test_model_sort(): def test_model_ownership(): - models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED)["results"] + models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED, function=Function.TRANSLATION)["results"] for model in models: assert model.is_subscribed is True def test_model_query(): query = "Mongo" - models = ModelFactory.list(query=query)["results"] + models = ModelFactory.list(query=query, function=Function.TRANSLATION)["results"] for model in models: assert query in model.name diff --git a/tests/functional/general_assets/data/asset_run_test_data.json b/tests/functional/general_assets/data/asset_run_test_data.json index abe7a3e9..e24df1ef 100644 --- a/tests/functional/general_assets/data/asset_run_test_data.json +++ b/tests/functional/general_assets/data/asset_run_test_data.json @@ -3,6 +3,10 @@ "id" : "61b097551efecf30109d32da", "data": "This is a test sentence." }, + "model2" : { + "id" : "60ddefab8d38c51c5885ee38", + "data": "https://aixplain-platform-assets.s3.amazonaws.com/samples/en/myname.mp3" + }, "pipeline": { "name": "SingleNodePipeline", "data": "This is a test sentence." diff --git a/tests/functional/model/run_model_test.py b/tests/functional/model/run_model_test.py index 47f351bb..04335d19 100644 --- a/tests/functional/model/run_model_test.py +++ b/tests/functional/model/run_model_test.py @@ -12,11 +12,17 @@ def pytest_generate_tests(metafunc): four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4) models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"] - predefined_models = ["Groq Llama 3 70B", "Chat GPT 3.5", "GPT-4o", "GPT 4 (32k)"] + predefined_models = [] + for predefined_model in ["Groq Llama 3 70B", "Chat GPT 3.5", "GPT-4o"]: + predefined_models.extend( + [ + m + for m in ModelFactory.list(query=predefined_model, function=Function.TEXT_GENERATION)["results"] + if m.name == predefined_model and "aiXplain-testing" not in str(m.supplier) + ] + ) recent_models = [model for model in models if model.created_at and model.created_at >= four_weeks_ago] - combined_models = recent_models + [ - ModelFactory.list(query=model, function=Function.TEXT_GENERATION)["results"][0] for model in predefined_models - ] + combined_models = recent_models + predefined_models metafunc.parametrize("llm_model", combined_models) @@ -24,10 +30,21 @@ def test_llm_run(llm_model): """Testing LLMs with history context""" assert isinstance(llm_model, LLM) - response = llm_model.run( data="What is my name?", history=[{"role": "user", "content": "Hello! 
My name is Thiago."}, {"role": "assistant", "content": "Hello!"}], ) assert response["status"] == "SUCCESS" assert "thiago" in response["data"].lower() + + +def test_run_async(): + """Testing Model Async""" + model = ModelFactory.get("60ddef828d38c51c5885d491") + + response = model.run_async("Test") + poll_url = response["url"] + response = model.sync_poll(poll_url) + + assert response["status"] == "SUCCESS" + assert "teste" in response["data"].lower() diff --git a/tests/unit/api_key_test.py b/tests/unit/api_key_test.py index 60d2371d..7da4e082 100644 --- a/tests/unit/api_key_test.py +++ b/tests/unit/api_key_test.py @@ -1,5 +1,5 @@ __author__ = "aixplain" -from aixplain.modules import APIKeyGlobalLimits +from aixplain.modules import APIKeyLimits from datetime import datetime import requests_mock import aixplain.utils.config as config @@ -13,7 +13,7 @@ def read_data(data_path): def test_api_key_service(): with requests_mock.Mocker() as mock: - model_id = "640b517694bf816d35a59125" + model_id = "test_asset_id" model_url = f"{config.BACKEND_URL}/sdk/models/{model_id}" model_map = read_data("tests/unit/mock_responses/model_response.json") mock.get(model_url, json=model_map) @@ -25,7 +25,7 @@ def test_api_key_service(): "accessKey": "access-key", "budget": 1000, "globalLimits": {"tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}, - "assetLimits": [{"assetId": model_id, "tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}], + "assetsLimits": [{"assetId": model_id, "tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}], "expiresAt": "2024-10-07T00:00:00Z", "isAdmin": False, } @@ -34,13 +34,11 @@ def test_api_key_service(): api_key = APIKeyFactory.create( name="Test API Key", asset_limits=[ - APIKeyGlobalLimits( + APIKeyLimits( model=model_id, token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100 ) ], - global_limits=APIKeyGlobalLimits( - token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100 - ), + global_limits=APIKeyLimits(token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100), budget=1000, expires_at=datetime(2024, 10, 7), ) @@ -65,3 +63,58 @@ def test_api_key_service(): mock.delete(delete_url, status_code=200) api_key.delete() + + +def test_setters(): + with requests_mock.Mocker() as mock: + model_id = "test_asset_id" + model_url = f"{config.BACKEND_URL}/sdk/models/{model_id}" + model_map = read_data("tests/unit/mock_responses/model_response.json") + mock.get(model_url, json=model_map) + + create_url = f"{config.BACKEND_URL}/sdk/api-keys" + api_key_response = { + "id": "key-id", + "name": "Name", + "accessKey": "access-key", + "budget": 1000, + "globalLimits": {"tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}, + "assetsLimits": [{"assetId": model_id, "tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}], + "expiresAt": "2024-10-07T00:00:00Z", + "isAdmin": False, + } + mock.post(create_url, json=api_key_response) + + api_key = APIKeyFactory.create( + name="Test API Key", + asset_limits=[ + APIKeyLimits( + model=model_id, token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100 + ) + ], + global_limits=APIKeyLimits(token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100), + budget=1000, + expires_at=datetime(2024, 10, 7), + ) + + api_key.set_token_per_day(1) + api_key.set_token_per_minute(1) + api_key.set_request_per_day(1) + api_key.set_request_per_minute(1) + api_key.set_token_per_day(1, model_id) + api_key.set_token_per_minute(1, model_id) + 
api_key.set_request_per_day(1, model_id) + api_key.set_request_per_minute(1, model_id) + + assert api_key.asset_limits[0].token_per_day == 1 + assert api_key.asset_limits[0].token_per_minute == 1 + assert api_key.asset_limits[0].request_per_day == 1 + assert api_key.asset_limits[0].request_per_minute == 1 + assert api_key.global_limits.token_per_day == 1 + assert api_key.global_limits.token_per_minute == 1 + assert api_key.global_limits.request_per_day == 1 + assert api_key.global_limits.request_per_minute == 1 + + +if __name__ == "__main__": + test_setters() diff --git a/tests/unit/benchmark_test.py b/tests/unit/benchmark_test.py new file mode 100644 index 00000000..167e4bcb --- /dev/null +++ b/tests/unit/benchmark_test.py @@ -0,0 +1,70 @@ +import requests_mock +import pytest +from urllib.parse import urljoin +from aixplain.utils import config +from aixplain.factories import MetricFactory, BenchmarkFactory +from aixplain.modules.model import Model +from aixplain.modules.dataset import Dataset + + +def test_create_benchmark_error_response(): + metric_list = [MetricFactory.get("66df3e2d6eb56336b6628171")] + with requests_mock.Mocker() as mock: + name = "test-benchmark" + dataset_list = [ + Dataset( + id="dataset1", + name="Dataset 1", + description="Test dataset", + function="test_func", + source_data="src", + target_data="tgt", + onboard_status="onboarded", + ) + ] + model_list = [ + Model(id="model1", name="Model 1", description="Test model", supplier="Test supplier", cost=10, version="v1") + ] + + url = urljoin(config.BACKEND_URL, "sdk/benchmarks") + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + error_response = {"statusCode": 400, "message": "Invalid request"} + mock.post(url, headers=headers, json=error_response, status_code=400) + + with pytest.raises(Exception) as excinfo: + BenchmarkFactory.create(name=name, dataset_list=dataset_list, model_list=model_list, metric_list=metric_list) + + assert "Benchmark Creation Error: Status 400 - {'statusCode': 400, 'message': 'Invalid request'}" in str(excinfo.value) + + +def test_get_benchmark_error(): + with requests_mock.Mocker() as mock: + benchmark_id = "test-benchmark-id" + url = urljoin(config.BACKEND_URL, f"sdk/benchmarks/{benchmark_id}") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 404, "message": "Benchmark not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + BenchmarkFactory.get(benchmark_id) + + assert "Benchmark GET Error: Status 404 - {'statusCode': 404, 'message': 'Benchmark not found'}" in str(excinfo.value) + + +def test_list_normalization_options_error(): + metric = MetricFactory.get("66df3e2d6eb56336b6628171") + with requests_mock.Mocker() as mock: + model = Model(id="model1", name="Test Model", description="Test model", supplier="Test supplier", cost=10, version="v1") + + url = urljoin(config.BACKEND_URL, "sdk/benchmarks/normalization-options") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + BenchmarkFactory.list_normalization_options(metric, model) + + assert "Error listing normalization options: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) diff --git 
a/tests/unit/corpus_test.py b/tests/unit/corpus_test.py new file mode 100644 index 00000000..07522c4d --- /dev/null +++ b/tests/unit/corpus_test.py @@ -0,0 +1,34 @@ +from aixplain.factories import CorpusFactory +import pytest +import requests_mock +from urllib.parse import urljoin +from aixplain.utils import config + + +def test_get_corpus_error_response(): + with requests_mock.Mocker() as mock: + corpus_id = "invalid_corpus_id" + url = urljoin(config.BACKEND_URL, f"sdk/corpora/{corpus_id}/overview") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Not Found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + CorpusFactory.get(corpus_id=corpus_id) + + assert "Corpus GET Error: Status 404 - {'message': 'Not Found'}" in str(excinfo.value) + + +def test_list_corpus_error_response(): + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/corpora/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + CorpusFactory.list(query="test_query", page_number=0, page_size=20) + + assert "Corpus List Error: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) diff --git a/tests/unit/dataset_test.py b/tests/unit/dataset_test.py new file mode 100644 index 00000000..25c57123 --- /dev/null +++ b/tests/unit/dataset_test.py @@ -0,0 +1,34 @@ +import pytest +import requests_mock +from aixplain.factories import DatasetFactory +from urllib.parse import urljoin +from aixplain.utils import config + + +def test_list_dataset_error_response(): + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/datasets/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + DatasetFactory.list(query="test_query", page_number=0, page_size=20) + + assert "Dataset List Error: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) + + +def test_get_dataset_error_response(): + with requests_mock.Mocker() as mock: + dataset_id = "invalid_dataset_id" + url = urljoin(config.BACKEND_URL, f"sdk/datasets/{dataset_id}/overview") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Not Found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + DatasetFactory.get(dataset_id=dataset_id) + + assert "Dataset GET Error: Status 404 - {'message': 'Not Found'}" in str(excinfo.value) diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 430fc338..54887950 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -1,6 +1,4 @@ - from dotenv import load_dotenv -from urllib.parse import urljoin import requests_mock from aixplain.enums import Function @@ -10,27 +8,44 @@ import pytest + @pytest.mark.parametrize( "status_code,error_message", [ - (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465,"Subscription-related error: Please ensure that your subscription is active and 
has not expired."), - (475,"Billing-related error: Please ensure you have enough credits to run this model. "), - (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), - (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), - (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), - + ( + 401, + "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: An unspecified error occurred while processing your request.", + ), + ( + 465, + "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: An unspecified error occurred while processing your request.", + ), + ( + 475, + "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + ), + ( + 485, + "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", + ), + ( + 495, + "An unspecified error occurred while processing your request.", + ), + (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), ], ) - def test_run_async_errors(status_code, error_message): base_url = config.MODELS_RUN_URL llm_id = "llm-id" - execute_url = urljoin(base_url, f"execute/{llm_id}") - + execute_url = f"{base_url}/{llm_id}" + ref_response = { + "error": "An unspecified error occurred while processing your request.", + } + with requests_mock.Mocker() as mock: - mock.post(execute_url, status_code=status_code) - test_llm = LLM(id=llm_id, name="Test llm",url=base_url, function=Function.TEXT_GENERATION) + mock.post(execute_url, status_code=status_code, json=ref_response) + test_llm = LLM(id=llm_id, name="Test llm", url=base_url, function=Function.TEXT_GENERATION) response = test_llm.run_async(data="input_data") assert response["status"] == "FAILED" - assert response["error_message"] == error_message \ No newline at end of file + assert response["error_message"] == error_message diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index c52bb950..03dccdbe 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -17,17 +17,66 @@ """ from dotenv import load_dotenv -from urllib.parse import urljoin import requests_mock load_dotenv() import re +import json from aixplain.utils import config from aixplain.modules import Model +from aixplain.modules.model.utils import build_payload, call_run_endpoint +from aixplain.factories import ModelFactory +from aixplain.enums import Function +from urllib.parse import urljoin import pytest +def test_build_payload(): + data = "input_data" + parameters = {"context": "context_data"} + ref_payload = json.dumps({"data": data, **parameters}) + hyp_payload = build_payload(data, parameters) + assert hyp_payload == ref_payload + + +def test_call_run_endpoint_async(): + base_url = config.MODELS_RUN_URL + model_id = "model-id" + execute_url = f"{base_url}/{model_id}" + payload = {"data": "input_data"} + ref_response = { + "completed": True, + "status": "IN_PROGRESS", + "data": "https://models.aixplain.com/api/v1/data/a90c2078-edfe-403f-acba-d2d94cf71f42", + } + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + response = call_run_endpoint(url=execute_url, 
api_key=config.TEAM_API_KEY, payload=payload) + + print(response) + assert response["completed"] == ref_response["completed"] + assert response["status"] == ref_response["status"] + assert response["url"] == ref_response["data"] + + +def test_call_run_endpoint_sync(): + base_url = config.MODELS_RUN_URL + model_id = "model-id" + execute_url = f"{base_url}/{model_id}".replace("/api/v1/execute", "/api/v2/execute") + payload = {"data": "input_data"} + ref_response = {"completed": True, "status": "SUCCESS", "data": "Hello"} + + with requests_mock.Mocker() as mock: + mock.post(execute_url, json=ref_response) + response = call_run_endpoint(url=execute_url, api_key=config.TEAM_API_KEY, payload=payload) + + assert response["completed"] == ref_response["completed"] + assert response["status"] == ref_response["status"] + assert response["data"] == ref_response["data"] + + def test_success_poll(): with requests_mock.Mocker() as mock: poll_url = "https://models.aixplain.com/api/v1/data/a90c2078-edfe-403f-acba-d2d94cf71f42" @@ -63,22 +112,80 @@ def test_failed_poll(): @pytest.mark.parametrize( "status_code,error_message", [ - (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475, "Billing-related error: Please ensure you have enough credits to run this model. "), - (485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."), - (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), - (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), + ( + 401, + "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: An unspecified error occurred while processing your request.", + ), + ( + 465, + "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: An unspecified error occurred while processing your request.", + ), + ( + 475, + "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", + ), + ( + 485, + "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. 
Details: An unspecified error occurred while processing your request.", + ), + ( + 495, + "An unspecified error occurred while processing your request.", + ), + (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): base_url = config.MODELS_RUN_URL model_id = "model-id" - execute_url = urljoin(base_url, f"execute/{model_id}") + execute_url = f"{base_url}/{model_id}" + ref_response = { + "error": "An unspecified error occurred while processing your request.", + } with requests_mock.Mocker() as mock: - mock.post(execute_url, status_code=status_code) + mock.post(execute_url, status_code=status_code, json=ref_response) test_model = Model(id=model_id, name="Test Model", url=base_url) response = test_model.run_async(data="input_data") assert response["status"] == "FAILED" assert response["error_message"] == error_message + + +def test_get_model_error_response(): + with requests_mock.Mocker() as mock: + model_id = "test-model-id" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 404, "message": "Model not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + ModelFactory.get(model_id) + + assert "Model GET Error: Failed to retrieve model test-model-id" in str(excinfo.value) + + +def test_get_assets_from_page_error(): + with requests_mock.Mocker() as mock: + query = "test-query" + page_number = 0 + page_size = 2 + url = urljoin(config.BACKEND_URL, "sdk/models/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 500, "message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + ModelFactory._get_assets_from_page( + query=query, + page_number=page_number, + page_size=page_size, + function=Function.TEXT_GENERATION, + suppliers=None, + source_languages=None, + target_languages=None, + ) + + assert "Listing Models Error: Failed to retrieve models" in str(excinfo.value) diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index d3c1c725..05ee7172 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -38,27 +38,61 @@ def test_create_pipeline(): assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name + @pytest.mark.parametrize( "status_code,error_message", [ - (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475,"Billing-related error: Please ensure you have enough credits to run this pipeline. "), - (485, "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access."), + (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475, "Billing-related error: Please ensure you have enough credits to run this pipeline. 
"), + ( + 485, + "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.", + ), (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), - ], ) - def test_run_async_errors(status_code, error_message): base_url = config.BACKEND_URL pipeline_id = "pipeline_id" execute_url = f"{base_url}/assets/pipeline/execution/run/{pipeline_id}" - + with requests_mock.Mocker() as mock: mock.post(execute_url, status_code=status_code) test_pipeline = Pipeline(id=pipeline_id, api_key=config.TEAM_API_KEY, name="Test Pipeline", url=base_url) response = test_pipeline.run_async(data="input_data") assert response["status"] == "FAILED" - assert response["error_message"] == error_message \ No newline at end of file + assert response["error_message"] == error_message + + +def test_list_pipelines_error_response(): + with requests_mock.Mocker() as mock: + query = "test-query" + page_number = 0 + page_size = 20 + url = urljoin(config.BACKEND_URL, "sdk/pipelines/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 400, "message": "Bad Request"} + mock.post(url, headers=headers, json=error_response, status_code=400) + + with pytest.raises(Exception) as excinfo: + PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) + + assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) + + +def test_get_pipeline_error_response(): + with requests_mock.Mocker() as mock: + pipeline_id = "test-pipeline-id" + url = urljoin(config.BACKEND_URL, f"sdk/pipelines/{pipeline_id}") + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + error_response = {"statusCode": 404, "message": "Pipeline not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + PipelineFactory.get(pipeline_id=pipeline_id) + + assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value)