diff --git a/aixplain/factories/api_key_factory.py b/aixplain/factories/api_key_factory.py index 4ac8f00a..c719c26b 100644 --- a/aixplain/factories/api_key_factory.py +++ b/aixplain/factories/api_key_factory.py @@ -4,12 +4,20 @@ from datetime import datetime from typing import Text, List, Optional, Dict, Union from aixplain.utils.file_utils import _request_with_retry -from aixplain.modules.api_key import APIKey, APIKeyGlobalLimits, APIKeyUsageLimit +from aixplain.modules.api_key import APIKey, APIKeyLimits, APIKeyUsageLimit class APIKeyFactory: backend_url = config.BACKEND_URL + @classmethod + def get(cls, api_key: Text) -> APIKey: + """Get an API key""" + for api_key_obj in cls.list(): + if str(api_key_obj.access_key).startswith(api_key[:4]) and str(api_key_obj.access_key).endswith(api_key[-4:]): + return api_key_obj + raise Exception(f"API Key Error: API key {api_key} not found") + @classmethod def list(cls) -> List[APIKey]: """List all API keys""" @@ -30,7 +38,7 @@ def list(cls) -> List[APIKey]: name=key["name"], budget=key["budget"] if "budget" in key else None, global_limits=key["globalLimits"] if "globalLimits" in key else None, - asset_limits=key["assetLimits"] if "assetLimits" in key else [], + asset_limits=key["assetsLimits"] if "assetsLimits" in key else [], expires_at=key["expiresAt"] if "expiresAt" in key else None, access_key=key["accessKey"], is_admin=key["isAdmin"], @@ -46,8 +54,8 @@ def create( cls, name: Text, budget: int, - global_limits: Union[Dict, APIKeyGlobalLimits], - asset_limits: List[Union[Dict, APIKeyGlobalLimits]], + global_limits: Union[Dict, APIKeyLimits], + asset_limits: List[Union[Dict, APIKeyLimits]], expires_at: datetime, ) -> APIKey: """Create a new API key""" @@ -84,6 +92,7 @@ def create( @classmethod def update(cls, api_key: APIKey) -> APIKey: """Update an existing API key""" + api_key.validate() try: resp = "Unspecified error" url = f"{cls.backend_url}/sdk/api-keys/{api_key.id}" @@ -112,12 +121,10 @@ def update(cls, api_key: APIKey) -> APIKey: raise Exception(f"API Key Update Error: Failed to update API key with ID {api_key.id}. Error: {str(resp)}") @classmethod - def get_usage_limit(cls, api_key: Text = config.TEAM_API_KEY, asset_id: Optional[Text] = None) -> APIKeyUsageLimit: - """Get API key usage limit""" + def get_usage_limits(cls, api_key: Text = config.TEAM_API_KEY, asset_id: Optional[Text] = None) -> List[APIKeyUsageLimit]: + """Get API key usage limits""" try: url = f"{config.BACKEND_URL}/sdk/api-keys/usage-limits" - if asset_id is not None: - url += f"?assetId={asset_id}" headers = {"Authorization": f"Token {api_key}", "Content-Type": "application/json"} logging.info(f"Start service for GET API Key Usage - {url} - {headers}") r = _request_with_retry("GET", url, headers=headers) @@ -128,11 +135,16 @@ def get_usage_limit(cls, api_key: Text = config.TEAM_API_KEY, asset_id: Optional raise Exception(f"{message}") if 200 <= r.status_code < 300: - return APIKeyUsageLimit( - request_count=resp["requestCount"], - request_count_limit=resp["requestCountLimit"], - token_count=resp["tokenCount"], - token_count_limit=resp["tokenCountLimit"], - ) + return [ + APIKeyUsageLimit( + daily_request_count=limit["requestCount"], + daily_request_limit=limit["requestCountLimit"], + daily_token_count=limit["tokenCount"], + daily_token_limit=limit["tokenCountLimit"], + model=limit["assetId"] if "assetId" in limit else None, + ) + for limit in resp + if asset_id is None or ("assetId" in limit and limit["assetId"] == asset_id) + ] else: raise Exception(f"API Key Usage Error: Failed to get usage. Error: {str(resp)}") diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index 5df7c924..209ff75d 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -80,6 +80,7 @@ def _create_model_from_response(cls, response: Dict) -> Model: return ModelClass( response["id"], response["name"], + description=response.get("description", ""), supplier=response["supplier"], api_key=response["api_key"], cost=response["pricing"], @@ -221,8 +222,8 @@ def _get_assets_from_page( @classmethod def list( cls, + function: Function, query: Optional[Text] = "", - function: Optional[Function] = None, suppliers: Optional[Union[Supplier, List[Supplier]]] = None, source_languages: Optional[Union[Language, List[Language]]] = None, target_languages: Optional[Union[Language, List[Language]]] = None, @@ -236,7 +237,7 @@ def list( """Gets the first k given models based on the provided task and language filters Args: - function (Optional[Function], optional): function filter. Defaults to None. + function (Function): function filter. source_languages (Optional[Union[Language, List[Language]]], optional): language filter of input data. Defaults to None. target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None. is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None. diff --git a/aixplain/factories/pipeline_factory/utils.py b/aixplain/factories/pipeline_factory/utils.py index 9584863f..7911c370 100644 --- a/aixplain/factories/pipeline_factory/utils.py +++ b/aixplain/factories/pipeline_factory/utils.py @@ -86,8 +86,23 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe node.label = node_json["label"] pipeline.add_node(node) + # Decision nodes' output parameters are defined based on their + # input parameters linked. So here we have to make sure that + # decision nodes (having passthrough parameter) should be first + # linked + link_jsons = response["links"][:] + decision_links = [] + for link_json in link_jsons: + for pm in link_json["paramMapping"]: + if pm["to"] == "passthrough": + decision_link_index = link_jsons.index(link_json) + decision_link = link_jsons.pop(decision_link_index) + decision_links.append(decision_link) + + link_jsons = decision_links + link_jsons + # instantiating links - for link_json in response["links"]: + for link_json in link_jsons: for param_mapping in link_json["paramMapping"]: link = Link( from_node=pipeline.get_node(link_json["from"]), diff --git a/aixplain/modules/__init__.py b/aixplain/modules/__init__.py index d49e29d4..4432e1ad 100644 --- a/aixplain/modules/__init__.py +++ b/aixplain/modules/__init__.py @@ -36,4 +36,4 @@ from .agent import Agent from .agent.tool import Tool from .team_agent import TeamAgent -from .api_key import APIKey, APIKeyGlobalLimits, APIKeyUsageLimit +from .api_key import APIKey, APIKeyLimits, APIKeyUsageLimit diff --git a/aixplain/modules/api_key.py b/aixplain/modules/api_key.py index 886b0dab..ae774c23 100644 --- a/aixplain/modules/api_key.py +++ b/aixplain/modules/api_key.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Text, Union -class APIKeyGlobalLimits: +class APIKeyLimits: def __init__( self, token_per_minute: int, @@ -27,19 +27,31 @@ def __init__( class APIKeyUsageLimit: - def __init__(self, request_count: int, request_count_limit: int, token_count: int, token_count_limit: int): - """Get the usage limits of an API key + def __init__( + self, + daily_request_count: int, + daily_request_limit: int, + daily_token_count: int, + daily_token_limit: int, + model: Optional[Union[Text, Model]] = None, + ): + """Get the usage limits of an API key globally (model equals to None) or for a specific model. Args: - request_count (int): number of requests made - request_count_limit (int): limit of requests - token_count (int): number of tokens used - token_count_limit (int): limit of tokens + daily_request_count (int): number of requests made + daily_request_limit (int): limit of requests + daily_token_count (int): number of tokens used + daily_token_limit (int): limit of tokens + model (Optional[Union[Text, Model]], optional): Model which the limits apply. Defaults to None. """ - self.request_count = request_count - self.request_count_limit = request_count_limit - self.token_count = token_count - self.token_count_limit = token_count_limit + self.daily_request_count = daily_request_count + self.daily_request_limit = daily_request_limit + self.daily_token_count = daily_token_count + self.daily_token_limit = daily_token_limit + if model is not None and isinstance(model, str): + from aixplain.factories import ModelFactory + + self.model = ModelFactory.get(model) class APIKey: @@ -48,8 +60,8 @@ def __init__( name: Text, expires_at: Optional[Union[datetime, Text]] = None, budget: Optional[float] = None, - asset_limits: List[APIKeyGlobalLimits] = [], - global_limits: Optional[Union[Dict, APIKeyGlobalLimits]] = None, + asset_limits: List[APIKeyLimits] = [], + global_limits: Optional[Union[Dict, APIKeyLimits]] = None, id: int = "", access_key: Optional[Text] = None, is_admin: bool = False, @@ -59,7 +71,7 @@ def __init__( self.budget = budget self.global_limits = global_limits if global_limits is not None and isinstance(global_limits, dict): - self.global_limits = APIKeyGlobalLimits( + self.global_limits = APIKeyLimits( token_per_minute=global_limits["tpm"], token_per_day=global_limits["tpd"], request_per_minute=global_limits["rpm"], @@ -68,7 +80,7 @@ def __init__( self.asset_limits = asset_limits for i, asset_limit in enumerate(self.asset_limits): if isinstance(asset_limit, dict): - self.asset_limits[i] = APIKeyGlobalLimits( + self.asset_limits[i] = APIKeyLimits( token_per_minute=asset_limit["tpm"], token_per_day=asset_limit["tpd"], request_per_minute=asset_limit["rpm"], @@ -110,7 +122,7 @@ def to_dict(self) -> Dict: "id": self.id, "name": self.name, "budget": self.budget, - "assetLimits": [], + "assetsLimits": [], "expiresAt": self.expires_at, } @@ -126,7 +138,7 @@ def to_dict(self) -> Dict: } for i, asset_limit in enumerate(self.asset_limits): - payload["assetLimits"].append( + payload["assetsLimits"].append( { "tpm": asset_limit.token_per_minute, "tpd": asset_limit.token_per_day, @@ -157,8 +169,6 @@ def get_usage(self, asset_id: Optional[Text] = None) -> APIKeyUsageLimit: url = f"{config.BACKEND_URL}/sdk/api-keys/{self.id}/usage-limits" headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET API Key Usage - {url} - {headers}") - if asset_id is not None: - url += f"?assetId={asset_id}" r = _request_with_retry("GET", url, headers=headers) resp = r.json() except Exception: @@ -167,11 +177,48 @@ def get_usage(self, asset_id: Optional[Text] = None) -> APIKeyUsageLimit: raise Exception(f"{message}") if 200 <= r.status_code < 300: - return APIKeyUsageLimit( - request_count=resp["requestCount"], - request_count_limit=resp["requestCountLimit"], - token_count=resp["tokenCount"], - token_count_limit=resp["tokenCountLimit"], - ) + return [ + APIKeyUsageLimit( + daily_request_count=limit["requestCount"], + daily_request_limit=limit["requestCountLimit"], + daily_token_count=limit["tokenCount"], + daily_token_limit=limit["tokenCountLimit"], + model=limit["assetId"] if "assetId" in limit else None, + ) + for limit in resp + if asset_id is None or ("assetId" in limit and limit["assetId"] == asset_id) + ] else: raise Exception(f"API Key Usage Error: Failed to get usage. Error: {str(resp)}") + + def __set_limit(self, limit: int, model: Optional[Union[Text, Model]], limit_type: Text) -> None: + """Set a limit for an API key""" + if model is None: + setattr(self.global_limits, limit_type, limit) + else: + if isinstance(model, Model): + model = model.id + is_found = False + for i, asset_limit in enumerate(self.asset_limits): + if asset_limit.model.id == model: + setattr(self.asset_limits[i], limit_type, limit) + is_found = True + break + if is_found is False: + raise Exception(f"Limit for Model {model} not found in the API key.") + + def set_token_per_day(self, token_per_day: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the token per day limit of an API key""" + self.__set_limit(token_per_day, model, "token_per_day") + + def set_token_per_minute(self, token_per_minute: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the token per minute limit of an API key""" + self.__set_limit(token_per_minute, model, "token_per_minute") + + def set_request_per_day(self, request_per_day: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the request per day limit of an API key""" + self.__set_limit(request_per_day, model, "request_per_day") + + def set_request_per_minute(self, request_per_minute: int, model: Optional[Union[Text, Model]] = None) -> None: + """Set the request per minute limit of an API key""" + self.__set_limit(request_per_minute, model, "request_per_minute") diff --git a/aixplain/modules/model/utils.py b/aixplain/modules/model/utils.py index a78455b7..d29da68b 100644 --- a/aixplain/modules/model/utils.py +++ b/aixplain/modules/model/utils.py @@ -55,8 +55,9 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: "error_message": "Model Run: An error occurred while processing your request.", } else: - response = {"status": status, "data": data, "completed": True} + response = resp else: + resp = resp["error"] if "error" in resp else resp if r.status_code == 401: error = f"Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {resp}" elif 460 <= r.status_code < 470: @@ -66,7 +67,7 @@ def call_run_endpoint(url: Text, api_key: Text, payload: Dict) -> Dict: elif 480 <= r.status_code < 490: error = f"Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {resp}" elif 490 <= r.status_code < 500: - error = f"Validation-related error: Please ensure all required fields are provided and correctly formatted. Details: {resp}" + error = f"{resp}" else: status_code = str(r.status_code) error = f"Status {status_code} - Unspecified error: {resp}" diff --git a/aixplain/modules/pipeline/designer/base.py b/aixplain/modules/pipeline/designer/base.py index 76e6196d..49c68463 100644 --- a/aixplain/modules/pipeline/designer/base.py +++ b/aixplain/modules/pipeline/designer/base.py @@ -142,14 +142,34 @@ def __init__( pipeline: "DesignerPipeline" = None, ): - assert from_param in from_node.outputs, "Invalid from param" - assert to_param in to_node.inputs, "Invalid to param" - if isinstance(from_param, Param): from_param = from_param.code if isinstance(to_param, Param): to_param = to_param.code + assert from_param in from_node.outputs, ( + "Invalid from param. " + "Make sure all input params are already linked accordingly" + ) + + fp_instance = from_node.outputs[from_param] + from .nodes import Decision + + if ( + isinstance(to_node, Decision) + and to_param == to_node.inputs.passthrough.code + ): + if from_param not in to_node.outputs: + to_node.outputs.create_param( + from_param, + fp_instance.data_type, + is_required=fp_instance.is_required, + ) + else: + to_node.outputs[from_param].data_type = fp_instance.data_type + + assert to_param in to_node.inputs, "Invalid to param" + self.from_node = from_node self.to_node = to_node self.from_param = from_param @@ -233,9 +253,7 @@ def __init__(self, node: "Node", *args, **kwargs): def add_param(self, param: Param) -> None: # check if param already registered if param in self: - raise ValueError( - f"Parameter with code '{param.code}' already exists." - ) + raise ValueError(f"Parameter with code '{param.code}' already exists.") self._params.append(param) # also set attribute on the node dynamically if there's no # any attribute with the same name @@ -353,9 +371,7 @@ def attach_to(self, pipeline: "DesignerPipeline"): :param pipeline: the pipeline """ assert not self.pipeline, "Node already attached to a pipeline" - assert ( - self not in pipeline.nodes - ), "Node already attached to a pipeline" + assert self not in pipeline.nodes, "Node already attached to a pipeline" assert self.type, "Node type not set" self.pipeline = pipeline diff --git a/aixplain/modules/pipeline/designer/nodes.py b/aixplain/modules/pipeline/designer/nodes.py index a6879e04..70ff302f 100644 --- a/aixplain/modules/pipeline/designer/nodes.py +++ b/aixplain/modules/pipeline/designer/nodes.py @@ -288,11 +288,15 @@ def __init__(self, value: DataType, path: List[Union[Node, int]], operation: Ope self.operation = operation self.type = type - if not self.path: - raise ValueError("Path is not valid, should be a list of nodes") + # Path can be an empty list in case the user has a valid case + # if not self.path: + # raise ValueError("Path is not valid, should be a list of nodes") # convert nodes to node numbers if they are nodes - self.path = [node.number if isinstance(node, Node) else node for node in self.path] + self.path = [ + node.number if isinstance(node, Node) else node + for node in self.path + ] def serialize(self) -> dict: return { diff --git a/pyproject.toml b/pyproject.toml index e0df02a2..1f034299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ namespaces = true [project] name = "aiXplain" -version = "0.2.21rc0" +version = "0.2.21rc1" description = "aiXplain SDK adds AI functions to software." readme = "README.md" requires-python = ">=3.5, <4" diff --git a/tests/functional/apikey/test_api.py b/tests/functional/apikey/test_api.py index 80b75189..221a58fb 100644 --- a/tests/functional/apikey/test_api.py +++ b/tests/functional/apikey/test_api.py @@ -1,5 +1,5 @@ from aixplain.factories.api_key_factory import APIKeyFactory -from aixplain.modules import APIKey, APIKeyGlobalLimits, APIKeyUsageLimit +from aixplain.modules import APIKey, APIKeyLimits, APIKeyUsageLimit from datetime import datetime import json import pytest @@ -16,7 +16,7 @@ def test_create_api_key_from_json(): api_key = APIKeyFactory.create( name=api_key_data["name"], asset_limits=[ - APIKeyGlobalLimits( + APIKeyLimits( model=api_key_data["asset_limits"][0]["model"], token_per_minute=api_key_data["asset_limits"][0]["token_per_minute"], token_per_day=api_key_data["asset_limits"][0]["token_per_day"], @@ -24,7 +24,7 @@ def test_create_api_key_from_json(): request_per_minute=api_key_data["asset_limits"][0]["request_per_minute"], ) ], - global_limits=APIKeyGlobalLimits( + global_limits=APIKeyLimits( token_per_minute=api_key_data["global_limits"]["token_per_minute"], token_per_day=api_key_data["global_limits"]["token_per_day"], request_per_day=api_key_data["global_limits"]["request_per_day"], @@ -60,8 +60,8 @@ def test_create_api_key_from_dict(): api_key_name = "Test API Key" api_key = APIKeyFactory.create( name=api_key_name, - asset_limits=[APIKeyGlobalLimits(**limit) for limit in api_key_dict["asset_limits"]], - global_limits=APIKeyGlobalLimits(**api_key_dict["global_limits"]), + asset_limits=[APIKeyLimits(**limit) for limit in api_key_dict["asset_limits"]], + global_limits=APIKeyLimits(**api_key_dict["global_limits"]), budget=api_key_dict["budget"], expires_at=datetime.strptime(api_key_dict["expires_at"], "%Y-%m-%dT%H:%M:%SZ"), ) @@ -92,8 +92,8 @@ def test_create_update_api_key_from_dict(): api_key_name = "Test API Key" api_key = APIKeyFactory.create( name=api_key_name, - asset_limits=[APIKeyGlobalLimits(**limit) for limit in api_key_dict["asset_limits"]], - global_limits=APIKeyGlobalLimits(**api_key_dict["global_limits"]), + asset_limits=[APIKeyLimits(**limit) for limit in api_key_dict["asset_limits"]], + global_limits=APIKeyLimits(**api_key_dict["global_limits"]), budget=api_key_dict["budget"], expires_at=datetime.strptime(api_key_dict["expires_at"], "%Y-%m-%dT%H:%M:%SZ"), ) @@ -102,6 +102,11 @@ def test_create_update_api_key_from_dict(): assert api_key.id != "" assert api_key.name == api_key_name + api_key_ = APIKeyFactory.get(api_key=api_key.access_key) + assert isinstance(api_key_, APIKey) + assert api_key_.id != "" + assert api_key_.name == api_key_name + api_key.global_limits.token_per_day = 222 api_key.global_limits.token_per_minute = 222 api_key.global_limits.request_per_day = 222 @@ -134,7 +139,65 @@ def test_list_api_keys(): if api_key.is_admin is False: usage = api_key.get_usage() - assert isinstance(usage, APIKeyUsageLimit) + assert isinstance(usage, list) + if len(usage) > 0: + assert isinstance(usage[0], APIKeyUsageLimit) + + +def test_list_update_api_keys(): + api_keys = APIKeyFactory.list() + assert isinstance(api_keys, list) + + for api_key in api_keys: + assert isinstance(api_key, APIKey) + assert api_key.id != "" + + from random import randint + + number = randint(0, 10000) + if api_key.global_limits is None: + api_key.global_limits = APIKeyLimits( + token_per_minute=number, + token_per_day=number, + request_per_day=number, + request_per_minute=number, + ) + else: + api_key.global_limits.token_per_day = number + api_key.global_limits.token_per_minute = number + api_key.global_limits.request_per_day = number + api_key.global_limits.request_per_minute = number + + if api_key.asset_limits is None: + api_key.asset_limits = [] + + if len(api_key.asset_limits) == 0: + api_key.asset_limits.append( + APIKeyLimits( + model="640b517694bf816d35a59125", + token_per_minute=number, + token_per_day=number, + request_per_day=number, + request_per_minute=number, + ) + ) + else: + api_key.asset_limits[0].request_per_day = number + api_key.asset_limits[0].request_per_minute = number + api_key.asset_limits[0].token_per_day = number + api_key.asset_limits[0].token_per_minute = number + api_key = APIKeyFactory.update(api_key) + + assert api_key.global_limits.token_per_day == number + assert api_key.global_limits.token_per_minute == number + assert api_key.global_limits.request_per_day == number + assert api_key.global_limits.request_per_minute == number + assert api_key.asset_limits[0].request_per_day == number + assert api_key.asset_limits[0].request_per_minute == number + assert api_key.asset_limits[0].token_per_day == number + assert api_key.asset_limits[0].token_per_minute == number + break + def test_list_update_api_keys(): diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index b0d8f6ef..266b04ea 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -33,7 +33,10 @@ def __get_asset_factory(asset_name): @pytest.mark.parametrize("asset_name", ["model", "dataset", "metric"]) def test_list(asset_name): AssetFactory = __get_asset_factory(asset_name) - asset_list = AssetFactory.list() + if asset_name == "model": + asset_list = AssetFactory.list(function=Function.TRANSLATION) + else: + asset_list = AssetFactory.list() assert asset_list["page_total"] == len(asset_list["results"]) @@ -62,7 +65,7 @@ def test_model_function(): def test_model_supplier(): desired_suppliers = [Supplier.GOOGLE] - models = ModelFactory.list(suppliers=desired_suppliers)["results"] + models = ModelFactory.list(suppliers=desired_suppliers, function=Function.TRANSLATION)["results"] for model in models: assert model.supplier.value in [desired_supplier.value for desired_supplier in desired_suppliers] @@ -89,14 +92,14 @@ def test_model_sort(): def test_model_ownership(): - models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED)["results"] + models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED, function=Function.TRANSLATION)["results"] for model in models: assert model.is_subscribed is True def test_model_query(): query = "Mongo" - models = ModelFactory.list(query=query)["results"] + models = ModelFactory.list(query=query, function=Function.TRANSLATION)["results"] for model in models: assert query in model.name diff --git a/tests/unit/api_key_test.py b/tests/unit/api_key_test.py index 60d2371d..7da4e082 100644 --- a/tests/unit/api_key_test.py +++ b/tests/unit/api_key_test.py @@ -1,5 +1,5 @@ __author__ = "aixplain" -from aixplain.modules import APIKeyGlobalLimits +from aixplain.modules import APIKeyLimits from datetime import datetime import requests_mock import aixplain.utils.config as config @@ -13,7 +13,7 @@ def read_data(data_path): def test_api_key_service(): with requests_mock.Mocker() as mock: - model_id = "640b517694bf816d35a59125" + model_id = "test_asset_id" model_url = f"{config.BACKEND_URL}/sdk/models/{model_id}" model_map = read_data("tests/unit/mock_responses/model_response.json") mock.get(model_url, json=model_map) @@ -25,7 +25,7 @@ def test_api_key_service(): "accessKey": "access-key", "budget": 1000, "globalLimits": {"tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}, - "assetLimits": [{"assetId": model_id, "tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}], + "assetsLimits": [{"assetId": model_id, "tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}], "expiresAt": "2024-10-07T00:00:00Z", "isAdmin": False, } @@ -34,13 +34,11 @@ def test_api_key_service(): api_key = APIKeyFactory.create( name="Test API Key", asset_limits=[ - APIKeyGlobalLimits( + APIKeyLimits( model=model_id, token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100 ) ], - global_limits=APIKeyGlobalLimits( - token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100 - ), + global_limits=APIKeyLimits(token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100), budget=1000, expires_at=datetime(2024, 10, 7), ) @@ -65,3 +63,58 @@ def test_api_key_service(): mock.delete(delete_url, status_code=200) api_key.delete() + + +def test_setters(): + with requests_mock.Mocker() as mock: + model_id = "test_asset_id" + model_url = f"{config.BACKEND_URL}/sdk/models/{model_id}" + model_map = read_data("tests/unit/mock_responses/model_response.json") + mock.get(model_url, json=model_map) + + create_url = f"{config.BACKEND_URL}/sdk/api-keys" + api_key_response = { + "id": "key-id", + "name": "Name", + "accessKey": "access-key", + "budget": 1000, + "globalLimits": {"tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}, + "assetsLimits": [{"assetId": model_id, "tpm": 100, "tpd": 1000, "rpd": 1000, "rpm": 100}], + "expiresAt": "2024-10-07T00:00:00Z", + "isAdmin": False, + } + mock.post(create_url, json=api_key_response) + + api_key = APIKeyFactory.create( + name="Test API Key", + asset_limits=[ + APIKeyLimits( + model=model_id, token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100 + ) + ], + global_limits=APIKeyLimits(token_per_minute=100, token_per_day=1000, request_per_day=1000, request_per_minute=100), + budget=1000, + expires_at=datetime(2024, 10, 7), + ) + + api_key.set_token_per_day(1) + api_key.set_token_per_minute(1) + api_key.set_request_per_day(1) + api_key.set_request_per_minute(1) + api_key.set_token_per_day(1, model_id) + api_key.set_token_per_minute(1, model_id) + api_key.set_request_per_day(1, model_id) + api_key.set_request_per_minute(1, model_id) + + assert api_key.asset_limits[0].token_per_day == 1 + assert api_key.asset_limits[0].token_per_minute == 1 + assert api_key.asset_limits[0].request_per_day == 1 + assert api_key.asset_limits[0].request_per_minute == 1 + assert api_key.global_limits.token_per_day == 1 + assert api_key.global_limits.token_per_minute == 1 + assert api_key.global_limits.request_per_day == 1 + assert api_key.global_limits.request_per_minute == 1 + + +if __name__ == "__main__": + test_setters() diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index b0dbe19a..54887950 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -14,25 +14,25 @@ [ ( 401, - "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: An unspecified error occurred while processing your request.", ), ( 465, - "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: An unspecified error occurred while processing your request.", ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "Validation-related error: Please ensure all required fields are provided and correctly formatted. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: {'error': 'An unspecified error occurred while processing your request.'}"), + (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message): diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 0907b8f1..03dccdbe 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -114,25 +114,25 @@ def test_failed_poll(): [ ( 401, - "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Unauthorized API key: Please verify the spelling of the API key and its current validity. Details: An unspecified error occurred while processing your request.", ), ( 465, - "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Subscription-related error: Please ensure that your subscription is active and has not expired. Details: An unspecified error occurred while processing your request.", ), ( 475, - "Billing-related error: Please ensure you have enough credits to run this model. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Billing-related error: Please ensure you have enough credits to run this model. Details: An unspecified error occurred while processing your request.", ), ( 485, - "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access. Details: An unspecified error occurred while processing your request.", ), ( 495, - "Validation-related error: Please ensure all required fields are provided and correctly formatted. Details: {'error': 'An unspecified error occurred while processing your request.'}", + "An unspecified error occurred while processing your request.", ), - (501, "Status 501 - Unspecified error: {'error': 'An unspecified error occurred while processing your request.'}"), + (501, "Status 501 - Unspecified error: An unspecified error occurred while processing your request."), ], ) def test_run_async_errors(status_code, error_message):