diff --git a/aixplain/decorators/api_key_checker.py b/aixplain/decorators/api_key_checker.py index d2611c0e..9fb317cb 100644 --- a/aixplain/decorators/api_key_checker.py +++ b/aixplain/decorators/api_key_checker.py @@ -3,7 +3,7 @@ def check_api_key(method): def wrapper(*args, **kwargs): - if config.TEAM_API_KEY == "": + if config.TEAM_API_KEY == "" and config.AIXPLAIN_API_KEY == "": raise Exception( "A 'TEAM_API_KEY' is required to run an asset. For help, please refer to the documentation (https://github.com/aixplain/aixplain#api-key-setup)" ) diff --git a/aixplain/enums/function.py b/aixplain/enums/function.py index a6d2e40a..67b5eba0 100644 --- a/aixplain/enums/function.py +++ b/aixplain/enums/function.py @@ -31,15 +31,11 @@ def load_functions(): api_key = config.TEAM_API_KEY - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/functions") - if aixplain_key != "": - api_key = aixplain_key - headers = {"x-aixplain-key": aixplain_key, "Content-Type": "application/json"} - else: - headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + headers = {"x-api-key": api_key, "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: raise Exception( @@ -61,5 +57,4 @@ def load_functions(): } return functions, functions_input_output - Function, FunctionInputOutput = load_functions() diff --git a/aixplain/enums/language.py b/aixplain/enums/language.py index 366d45f5..674940ab 100644 --- a/aixplain/enums/language.py +++ b/aixplain/enums/language.py @@ -31,15 +31,11 @@ def load_languages(): api_key = config.TEAM_API_KEY - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/languages") - if aixplain_key != "": - api_key = aixplain_key - headers = {"x-aixplain-key": aixplain_key, "Content-Type": "application/json"} - else: - headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + headers = {"x-api-key": api_key, "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: raise Exception( diff --git a/aixplain/enums/license.py b/aixplain/enums/license.py index 1943ec44..14527829 100644 --- a/aixplain/enums/license.py +++ b/aixplain/enums/license.py @@ -32,15 +32,11 @@ def load_licenses(): try: api_key = config.TEAM_API_KEY - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/licenses") - if aixplain_key != "": - api_key = aixplain_key - headers = {"x-aixplain-key": aixplain_key, "Content-Type": "application/json"} - else: - headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + headers = {"x-api-key": api_key, "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: raise Exception( diff --git a/aixplain/enums/supplier.py b/aixplain/enums/supplier.py index ecc29998..2bca01b1 100644 --- a/aixplain/enums/supplier.py +++ b/aixplain/enums/supplier.py @@ -39,15 +39,11 @@ def clean_name(name): def load_suppliers(): api_key = config.TEAM_API_KEY - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/suppliers") - if aixplain_key != "": - api_key = aixplain_key - headers = {"x-aixplain-key": aixplain_key, "Content-Type": "application/json"} - else: - headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + headers = {"x-api-key": api_key, "Content-Type": "application/json"} logging.debug(f"Start service for GET API Creation - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) if not 200 <= r.status_code < 300: diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index c56d1fd8..39ae5678 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -176,11 +176,9 @@ def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> Agent: from aixplain.factories.agent_factory.utils import build_agent url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent_id}") - if config.AIXPLAIN_API_KEY != "": - headers = {"x-aixplain-key": f"{config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} - else: - api_key = api_key if api_key is not None else config.TEAM_API_KEY - headers = {"x-api-key": api_key, "Content-Type": "application/json"} + + api_key = api_key if api_key is not None else config.TEAM_API_KEY + headers = {"x-api-key": api_key, "Content-Type": "application/json"} logging.info(f"Start service for GET Agent - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() diff --git a/aixplain/factories/asset_factory.py b/aixplain/factories/asset_factory.py index 460f7cfa..51192b2a 100644 --- a/aixplain/factories/asset_factory.py +++ b/aixplain/factories/asset_factory.py @@ -28,7 +28,7 @@ class AssetFactory: - aixplain_key = config.AIXPLAIN_API_KEY + backend_url = config.BACKEND_URL @abstractmethod diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index 305fb5d9..743ed7fa 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -43,7 +43,7 @@ class BenchmarkFactory: backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY + backend_url = config.BACKEND_URL @classmethod @@ -69,10 +69,8 @@ def _get_benchmark_jobs_from_benchmark_id(cls, benchmark_id: Text) -> List[Bench List[BenchmarkJob]: List of associated benchmark jobs """ url = urljoin(cls.backend_url, f"sdk/benchmarks/{benchmark_id}/jobs") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) resp = r.json() job_list = [cls._create_benchmark_job_from_response(job_info) for job_info in resp] @@ -107,10 +105,7 @@ def get(cls, benchmark_id: str) -> Benchmark: resp = None try: url = urljoin(cls.backend_url, f"sdk/benchmarks/{benchmark_id}") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET Benchmark - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -145,10 +140,7 @@ def get_job(cls, job_id: Text) -> BenchmarkJob: BenchmarkJob: Created 'BenchmarkJob' object """ url = urljoin(cls.backend_url, f"sdk/benchmarks/jobs/{job_id}") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) resp = r.json() benchmarkJob = cls._create_benchmark_job_from_response(resp) @@ -235,10 +227,7 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: """ try: url = urljoin(cls.backend_url, "sdk/benchmarks/normalization-options") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} payload = json.dumps({"metricId": metric.id, "modelIds": [model.id]}) r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index 3b9c5e4b..db7aa44e 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -48,7 +48,6 @@ class CorpusFactory(AssetFactory): - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod @@ -117,10 +116,8 @@ def get(cls, corpus_id: Text) -> Corpus: """ try: url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET Corpus - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -172,10 +169,8 @@ def list( Dict: list of corpora in agreement with the filters, page number, page total and total elements """ url = urljoin(cls.backend_url, "sdk/corpora/paginate") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} assert 0 < page_size <= 100, "Corpus List Error: Page size must be greater than 0 and not exceed 100." payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]} diff --git a/aixplain/factories/data_factory.py b/aixplain/factories/data_factory.py index 3f512aaf..1879b321 100644 --- a/aixplain/factories/data_factory.py +++ b/aixplain/factories/data_factory.py @@ -46,7 +46,6 @@ class DataFactory(AssetFactory): backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod @@ -92,10 +91,8 @@ def get(cls, data_id: Text) -> Data: Data: Created 'Data' object """ url = urljoin(cls.backend_url, f"sdk/data/{data_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET Data - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index 081513c0..c7ccad70 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -57,7 +57,7 @@ class DatasetFactory(AssetFactory): backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY + backend_url = config.BACKEND_URL @classmethod @@ -164,10 +164,8 @@ def get(cls, dataset_id: Text) -> Dataset: """ try: url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET Dataset - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -214,10 +212,8 @@ def list( Dict: list of datasets in agreement with the filters, page number, page total and total elements """ url = urljoin(cls.backend_url, "sdk/datasets/paginate") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} assert 0 < page_size <= 100, "Dataset List Error: Page size must be greater than 0 and not exceed 100." payload = { diff --git a/aixplain/factories/finetune_factory/__init__.py b/aixplain/factories/finetune_factory/__init__.py index 7a23c527..238d0d0c 100644 --- a/aixplain/factories/finetune_factory/__init__.py +++ b/aixplain/factories/finetune_factory/__init__.py @@ -44,7 +44,6 @@ class FinetuneFactory: backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod diff --git a/aixplain/factories/metric_factory.py b/aixplain/factories/metric_factory.py index a0372827..9f42fb3e 100644 --- a/aixplain/factories/metric_factory.py +++ b/aixplain/factories/metric_factory.py @@ -39,7 +39,6 @@ class MetricFactory: backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod @@ -76,10 +75,7 @@ def get(cls, metric_id: Text) -> Metric: resp, status_code = None, 200 try: url = urljoin(cls.backend_url, f"sdk/metrics/{metric_id}") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET Metric - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -126,10 +122,7 @@ def list( if is_reference_required is not None: filter_params["referenceRequired"] = 1 if is_reference_required else 0 - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers, params=filter_params) resp = r.json() logging.info(f"Listing Metrics: Status of getting metrics: {resp}") diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index b6588023..052750a7 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -41,7 +41,6 @@ class ModelFactory: backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod @@ -107,10 +106,8 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: resp = None try: url = urljoin(cls.backend_url, f"sdk/models/{model_id}") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for GET Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() @@ -196,10 +193,8 @@ def _get_assets_from_page( filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}] if len(lang_filter_params) != 0: filter_params["ioFilter"] = lang_filter_params - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") r = _request_with_retry("post", url, headers=headers, json=filter_params) diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index ef330de0..f960d6da 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -43,7 +43,6 @@ class PipelineFactory: backend_url (str): The URL for the backend. """ - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod @@ -65,11 +64,6 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: "Authorization": f"Token {api_key}", "Content-Type": "application/json", } - elif cls.aixplain_key != "": - headers = { - "x-aixplain-key": f"{cls.aixplain_key}", - "Content-Type": "application/json", - } else: headers = { "Authorization": f"Token {config.TEAM_API_KEY}", @@ -125,13 +119,8 @@ def get_assets_from_page(cls, page_number: int) -> List[Pipeline]: """ try: url = urljoin(cls.backend_url, f"sdk/pipelines/?pageNumber={page_number}") - if cls.aixplain_key != "": - headers = { - "x-aixplain-key": f"{cls.aixplain_key}", - "Content-Type": "application/json", - } - else: - headers = { + + headers = { "Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json", } @@ -181,13 +170,8 @@ def list( ) -> Dict: url = urljoin(cls.backend_url, "sdk/pipelines/paginate") - if cls.aixplain_key != "": - headers = { - "x-aixplain-key": f"{cls.aixplain_key}", - "Content-Type": "application/json", - } - else: - headers = { + + headers = { "Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json", } diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 3f65b4b0..0819d989 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -153,11 +153,8 @@ def list(cls) -> Dict: def get(cls, agent_id: Text, api_key: Optional[Text] = None) -> TeamAgent: """Get agent by id.""" url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{agent_id}") - if config.AIXPLAIN_API_KEY != "": - headers = {"x-aixplain-key": f"{config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} - else: - api_key = api_key if api_key is not None else config.TEAM_API_KEY - headers = {"x-api-key": api_key, "Content-Type": "application/json"} + api_key = api_key if api_key is not None else config.TEAM_API_KEY + headers = {"x-api-key": api_key, "Content-Type": "application/json"} logging.info(f"Start service for GET Team Agent - {url} - {headers}") try: r = _request_with_retry("get", url, headers=headers) diff --git a/aixplain/factories/wallet_factory.py b/aixplain/factories/wallet_factory.py index 01c0ac2e..1591dc2e 100644 --- a/aixplain/factories/wallet_factory.py +++ b/aixplain/factories/wallet_factory.py @@ -6,7 +6,6 @@ class WalletFactory: - aixplain_key = config.AIXPLAIN_API_KEY backend_url = config.BACKEND_URL @classmethod diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index b7aad7aa..c436b84a 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -36,6 +36,7 @@ from aixplain.modules.agent.tool import Tool from aixplain.modules.agent.tool.model_tool import ModelTool from aixplain.modules.agent.tool.pipeline_tool import PipelineTool +from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -238,9 +239,12 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + # build query + input_data = process_variables(query, data, parameters, self.description) + payload = { "id": self.id, - "query": FileFactory.to_link(query), + "query": input_data, "sessionId": session_id, "history": history, "executionParams": { diff --git a/aixplain/modules/agent/utils.py b/aixplain/modules/agent/utils.py new file mode 100644 index 00000000..03de61d1 --- /dev/null +++ b/aixplain/modules/agent/utils.py @@ -0,0 +1,22 @@ +from typing import Dict, Text, Union +import re + + +def process_variables(query: Text, data: Union[Dict, Text], parameters: Dict, agent_description: Text) -> Text: + from aixplain.factories.file_factory import FileFactory + + variables = re.findall(r"(? dict: url = urljoin(config.BACKEND_URL, f"sdk/benchmarks/jobs/{job_id}") - if config.AIXPLAIN_API_KEY != "": - headers = {"x-aixplain-key": f"{config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} r = _request_with_retry("get", url, headers=headers) resp = r.json() return resp diff --git a/aixplain/modules/model/response.py b/aixplain/modules/model/response.py index 94ddcb9d..99e32074 100644 --- a/aixplain/modules/model/response.py +++ b/aixplain/modules/model/response.py @@ -40,8 +40,11 @@ def __getitem__(self, key: Text) -> Any: return self.run_time raise KeyError(f"Key '{key}' not found in ModelResponse.") - def get(self, key: Text) -> Any: - return self[key] + def get(self, key: Text, default: Optional[Any] = None) -> Any: + try: + return self[key] + except KeyError: + return default def __repr__(self) -> str: fields = [] diff --git a/aixplain/modules/pipeline/generate.py b/aixplain/modules/pipeline/generate.py index 46c95482..8bfeecb3 100644 --- a/aixplain/modules/pipeline/generate.py +++ b/aixplain/modules/pipeline/generate.py @@ -103,7 +103,7 @@ def fetch_functions(): Fetch functions from the backend """ api_key = config.TEAM_API_KEY - aixplain_key = config.AIXPLAIN_API_KEY + backend_url = config.BACKEND_URL url = urljoin(backend_url, "sdk/functions") @@ -111,10 +111,7 @@ def fetch_functions(): "Content-Type": "application/json", } - if aixplain_key: - headers["x-aixplain-key"] = aixplain_key - else: - headers["x-api-key"] = api_key + headers["x-api-key"] = api_key r = requests.get(url, headers=headers) try: diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index 08d820f0..f92b437d 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -34,6 +34,7 @@ from aixplain.enums.storage_type import StorageType from aixplain.modules.model import Model from aixplain.modules.agent import Agent, OutputFormat +from aixplain.modules.agent.utils import process_variables from typing import Dict, List, Text, Optional, Union from urllib.parse import urljoin @@ -222,9 +223,12 @@ def run_async( headers = {"x-api-key": self.api_key, "Content-Type": "application/json"} + # build query + input_data = process_variables(query, data, parameters, self.description) + payload = { "id": self.id, - "query": FileFactory.to_link(query), + "query": input_data, "sessionId": session_id, "history": history, "executionParams": { diff --git a/aixplain/utils/config.py b/aixplain/utils/config.py index 03bbdccf..b47bc4f7 100644 --- a/aixplain/utils/config.py +++ b/aixplain/utils/config.py @@ -23,6 +23,15 @@ # GET THE API KEY FROM CMD TEAM_API_KEY = os.getenv("TEAM_API_KEY", "") AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "") + +if AIXPLAIN_API_KEY and TEAM_API_KEY: + if AIXPLAIN_API_KEY != TEAM_API_KEY: + raise Exception("Conflicting API keys: 'AIXPLAIN_API_KEY' and 'TEAM_API_KEY' are both provided but do not match. Please provide only one API key.") + + +if AIXPLAIN_API_KEY and not TEAM_API_KEY: + TEAM_API_KEY = AIXPLAIN_API_KEY + PIPELINE_API_KEY = os.getenv("PIPELINE_API_KEY", "") MODEL_API_KEY = os.getenv("MODEL_API_KEY", "") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") diff --git a/aixplain/utils/file_utils.py b/aixplain/utils/file_utils.py index 02ddc5ef..0e617397 100644 --- a/aixplain/utils/file_utils.py +++ b/aixplain/utils/file_utils.py @@ -122,12 +122,8 @@ def upload_data( tags = [] payload = {"contentType": content_type, "originalName": file_name, "tags": ",".join(tags), "license": license.value} - if config.AIXPLAIN_API_KEY != "": - team_key = config.AIXPLAIN_API_KEY - headers = {"x-aixplain-key": team_key} - else: - team_key = config.TEAM_API_KEY - headers = {"Authorization": "token " + team_key} + team_key = config.TEAM_API_KEY + headers = {"Authorization": "token " + team_key} r = _request_with_retry("post", url, headers=headers, data=payload) response = r.json() diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 55d671e0..3f54d470 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -163,3 +163,46 @@ def test_delete_agent_in_use(delete_agents_and_team_agents): with pytest.raises(Exception) as exc_info: agent.delete() assert str(exc_info.value) == "Agent Deletion Error (HTTP 403): err.agent_is_in_use." + + +def test_update_tools_of_agent(run_input_map, delete_agents_and_team_agents): + assert delete_agents_and_team_agents + + agent = AgentFactory.create( + name=run_input_map["agent_name"], description=run_input_map["agent_name"], llm_id=run_input_map["llm_id"] + ) + assert agent is not None + assert agent.status == AssetStatus.DRAFT + assert len(agent.tools) == 0 + + tools = [] + if "model_tools" in run_input_map: + for tool in run_input_map["model_tools"]: + tool_ = copy.copy(tool) + for supplier in Supplier: + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier.value["code"].lower(), + supplier.value["name"].lower(), + ]: + tool_["supplier"] = supplier + break + tools.append(AgentFactory.create_model_tool(**tool_)) + + if "pipeline_tools" in run_input_map: + for tool in run_input_map["pipeline_tools"]: + tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) + + agent.tools = tools + agent.update() + + agent = AgentFactory.get(agent.id) + assert len(agent.tools) == len(tools) + + removed_tool = agent.tools.pop() + agent.update() + + agent = AgentFactory.get(agent.id) + assert len(agent.tools) == len(tools) - 1 + assert removed_tool not in agent.tools + + agent.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index 44ea5dbc..e60e453a 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -160,3 +160,62 @@ def test_fail_non_existent_llm(): tools=[AgentFactory.create_model_tool(function=Function.TRANSLATION)], ) assert str(exc_info.value) == "Large Language Model with ID 'non_existent_llm' not found." + +def test_add_remove_agents_from_team_agent(run_input_map, delete_agents_and_team_agents): + assert delete_agents_and_team_agents + + agents = [] + for agent in run_input_map["agents"]: + tools = [] + if "model_tools" in agent: + for tool in agent["model_tools"]: + tool_ = copy(tool) + for supplier in Supplier: + if tool["supplier"] is not None and tool["supplier"].lower() in [ + supplier.value["code"].lower(), + supplier.value["name"].lower(), + ]: + tool_["supplier"] = supplier + break + tools.append(AgentFactory.create_model_tool(**tool_)) + if "pipeline_tools" in agent: + for tool in agent["pipeline_tools"]: + tools.append(AgentFactory.create_pipeline_tool(pipeline=tool["pipeline_id"], description=tool["description"])) + + agent = AgentFactory.create( + name=agent["agent_name"], description=agent["agent_name"], llm_id=agent["llm_id"], tools=tools + ) + agents.append(agent) + + team_agent = TeamAgentFactory.create( + name=run_input_map["team_agent_name"], + agents=agents, + description=run_input_map["team_agent_name"], + llm_id=run_input_map["llm_id"], + use_mentalist_and_inspector=True, + ) + + assert team_agent is not None + assert team_agent.status == AssetStatus.DRAFT + + new_agent = AgentFactory.create( + name="New Agent", + description="Agent added to team", + llm_id=run_input_map["llm_id"], + ) + team_agent.agents.append(new_agent) + team_agent.update() + + team_agent = TeamAgentFactory.get(team_agent.id) + assert new_agent.id in [agent.id for agent in team_agent.agents] + assert len(team_agent.agents) == len(agents) + 1 + + removed_agent = team_agent.agents.pop(0) + team_agent.update() + + team_agent = TeamAgentFactory.get(team_agent.id) + assert removed_agent.id not in [agent.id for agent in team_agent.agents] + assert len(team_agent.agents) == len(agents) + + team_agent.delete() + new_agent.delete() diff --git a/tests/unit/agent_test.py b/tests/unit/agent_test.py index 9e38937f..ce1eac63 100644 --- a/tests/unit/agent_test.py +++ b/tests/unit/agent_test.py @@ -6,6 +6,7 @@ from aixplain.utils import config from aixplain.factories import AgentFactory from aixplain.modules.agent import PipelineTool, ModelTool +from aixplain.modules.agent.utils import process_variables from urllib.parse import urljoin @@ -234,7 +235,7 @@ def test_run_success(): url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") agent.url = url with requests_mock.Mocker() as mock: - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"x-api-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} ref_response = {"data": "www.aixplain.com", "status": "IN_PROGRESS"} mock.post(url, headers=headers, json=ref_response) @@ -244,3 +245,23 @@ def test_run_success(): ) assert response["status"] == "IN_PROGRESS" assert response["url"] == ref_response["data"] + + +def test_run_variable_error(): + agent = Agent("123", "Test Agent", "Translate the input data into {target_language}") + with pytest.raises(Exception) as exc_info: + agent.run_async(data={"query": "Hello, how are you?"}, output_format=OutputFormat.MARKDOWN) + assert ( + str(exc_info.value) + == "Variable 'target_language' not found in data or parameters. This variable is required by the agent according to its description ('Translate the input data into {target_language}')." + ) + + +def test_process_variables(): + query = "Hello, how are you?" + data = {"target_language": "English"} + agent_description = "Translate the input data into {target_language}" + assert process_variables(query=query, data=data, parameters={}, agent_description=agent_description) == { + "input": "Hello, how are you?", + "target_language": "English", + } diff --git a/tests/unit/benchmark_test.py b/tests/unit/benchmark_test.py index 167e4bcb..08a91ea3 100644 --- a/tests/unit/benchmark_test.py +++ b/tests/unit/benchmark_test.py @@ -42,7 +42,7 @@ def test_get_benchmark_error(): with requests_mock.Mocker() as mock: benchmark_id = "test-benchmark-id" url = urljoin(config.BACKEND_URL, f"sdk/benchmarks/{benchmark_id}") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"statusCode": 404, "message": "Benchmark not found"} mock.get(url, headers=headers, json=error_response, status_code=404) @@ -59,7 +59,7 @@ def test_list_normalization_options_error(): model = Model(id="model1", name="Test Model", description="Test model", supplier="Test supplier", cost=10, version="v1") url = urljoin(config.BACKEND_URL, "sdk/benchmarks/normalization-options") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"message": "Internal Server Error"} mock.post(url, headers=headers, json=error_response, status_code=500) diff --git a/tests/unit/corpus_test.py b/tests/unit/corpus_test.py index 07522c4d..bc240382 100644 --- a/tests/unit/corpus_test.py +++ b/tests/unit/corpus_test.py @@ -9,7 +9,7 @@ def test_get_corpus_error_response(): with requests_mock.Mocker() as mock: corpus_id = "invalid_corpus_id" url = urljoin(config.BACKEND_URL, f"sdk/corpora/{corpus_id}/overview") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"message": "Not Found"} mock.get(url, headers=headers, json=error_response, status_code=404) @@ -23,7 +23,7 @@ def test_get_corpus_error_response(): def test_list_corpus_error_response(): with requests_mock.Mocker() as mock: url = urljoin(config.BACKEND_URL, "sdk/corpora/paginate") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"message": "Internal Server Error"} mock.post(url, headers=headers, json=error_response, status_code=500) diff --git a/tests/unit/dataset_test.py b/tests/unit/dataset_test.py index 25c57123..721a405c 100644 --- a/tests/unit/dataset_test.py +++ b/tests/unit/dataset_test.py @@ -8,7 +8,7 @@ def test_list_dataset_error_response(): with requests_mock.Mocker() as mock: url = urljoin(config.BACKEND_URL, "sdk/datasets/paginate") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"message": "Internal Server Error"} mock.post(url, headers=headers, json=error_response, status_code=500) @@ -23,7 +23,7 @@ def test_get_dataset_error_response(): with requests_mock.Mocker() as mock: dataset_id = "invalid_dataset_id" url = urljoin(config.BACKEND_URL, f"sdk/datasets/{dataset_id}/overview") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"message": "Not Found"} mock.get(url, headers=headers, json=error_response, status_code=404) diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index b45b6ae0..452d9ac5 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -87,7 +87,7 @@ def test_success_poll(): hyp_response = test_model.poll(poll_url=poll_url) assert isinstance(hyp_response, ModelResponse) assert hyp_response["completed"] == ref_response["completed"] - assert hyp_response["status"] == ResponseStatus.SUCCESS + assert hyp_response.get("status") == ResponseStatus.SUCCESS def test_failed_poll(): @@ -152,7 +152,7 @@ def test_get_model_error_response(): with requests_mock.Mocker() as mock: model_id = "test-model-id" url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"statusCode": 404, "message": "Model not found"} mock.get(url, headers=headers, json=error_response, status_code=404) @@ -169,7 +169,7 @@ def test_get_assets_from_page_error(): page_number = 0 page_size = 2 url = urljoin(config.BACKEND_URL, "sdk/models/paginate") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"statusCode": 500, "message": "Internal Server Error"} mock.post(url, headers=headers, json=error_response, status_code=500) diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index 05ee7172..d1b0f9b2 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -72,7 +72,7 @@ def test_list_pipelines_error_response(): page_number = 0 page_size = 20 url = urljoin(config.BACKEND_URL, "sdk/pipelines/paginate") - headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + headers = {"Authorization": f"Token {config.AIXPLAIN_API_KEY}", "Content-Type": "application/json"} error_response = {"statusCode": 400, "message": "Bad Request"} mock.post(url, headers=headers, json=error_response, status_code=400)