From c11c4028596ec3bb4eac521f6146e1baf51ed2d1 Mon Sep 17 00:00:00 2001 From: xainaz Date: Thu, 10 Oct 2024 14:16:22 +0300 Subject: [PATCH 1/4] Improve error log for: Benchmark, Corpus, Dataset, Model, Pipeline --- aixplain/factories/benchmark_factory.py | 49 ++++-- aixplain/factories/corpus_factory.py | 139 +++++++++------- aixplain/factories/dataset_factory.py | 154 ++++++++++-------- aixplain/factories/model_factory.py | 33 ++-- .../factories/pipeline_factory/__init__.py | 64 +++++--- 5 files changed, 261 insertions(+), 178 deletions(-) diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index 57d4a833..ea983075 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -22,10 +22,8 @@ """ import logging -from typing import Dict, List, Optional, Text +from typing import Dict, List, Text import json -import pandas as pd -from pathlib import Path from aixplain.enums.supplier import Supplier from aixplain.modules import Dataset, Metric, Model from aixplain.modules.benchmark_job import BenchmarkJob @@ -34,9 +32,8 @@ from aixplain.factories.dataset_factory import DatasetFactory from aixplain.factories.model_factory import ModelFactory from aixplain.utils import config -from aixplain.utils.file_utils import _request_with_retry, save_file +from aixplain.utils.file_utils import _request_with_retry from urllib.parse import urljoin -from warnings import warn class BenchmarkFactory: @@ -117,7 +114,14 @@ def get(cls, benchmark_id: str) -> Benchmark: logging.info(f"Start service for GET Benchmark - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - benchmark = cls._create_benchmark_from_response(resp) + if 200 <= r.status_code < 300: + benchmark = cls._create_benchmark_from_response(resp) + logging.info(f"Benchmark {benchmark_id} retrieved successfully.") + return benchmark + else: + error_message = f"Benchmark GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) except Exception as e: status_code = 400 if resp is not None and "statusCode" in resp: @@ -125,10 +129,9 @@ def get(cls, benchmark_id: str) -> Benchmark: message = resp["message"] message = f"Benchmark Creation: Status {status_code} - {message}" else: - message = f"Benchmark Creation: Unspecified Error" + message = "Benchmark Creation: Unspecified Error" logging.error(f"Benchmark Creation Failed: {e}") raise Exception(f"Status {status_code}: {message}") - return benchmark @classmethod def get_job(cls, job_id: Text) -> BenchmarkJob: @@ -189,7 +192,7 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], """ payload = {} try: - url = urljoin(cls.backend_url, f"sdk/benchmarks") + url = urljoin(cls.backend_url, "sdk/benchmarks") headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} payload = { "name": name, @@ -204,8 +207,13 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], payload = json.dumps(clean_payload) r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() - logging.info(f"Creating Benchmark Job: Status for {name}: {resp}") - return cls.get(resp["id"]) + if 200 <= r.status_code < 300: + logging.info(f"Benchmark {name} created successfully: {resp}") + return cls.get(resp["id"]) + else: + error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) except Exception as e: error_message = f"Creating Benchmark Job: Error in Creating Benchmark with payload {payload} : {e}" logging.error(error_message, exc_info=True) @@ -223,7 +231,7 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: List[str]: List of supported normalization options """ try: - url = urljoin(cls.backend_url, f"sdk/benchmarks/normalization-options") + url = urljoin(cls.backend_url, "sdk/benchmarks/normalization-options") if cls.aixplain_key != "": headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} else: @@ -231,9 +239,17 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: payload = json.dumps({"metricId": metric.id, "modelIds": [model.id]}) r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() - logging.info(f"Listing Normalization Options: Status of listing options: {resp}") - normalization_options = [item["value"] for item in resp] - return normalization_options + + if 200 <= r.status_code < 300: + logging.info(f"Listing Normalization Options: Status of listing options: {resp}") + normalization_options = [item["value"] for item in resp] + return normalization_options + else: + error_message = ( + f"Error listing normalization options: Status {r.status_code} - {resp.get('message', 'No message')}" + ) + logging.error(error_message) + return [] except Exception as e: error_message = f"Listing Normalization Options: Error in getting Normalization Options: {e}" logging.error(error_message, exc_info=True) @@ -255,7 +271,8 @@ def __get_model_name(model_id): if model.version is not None: name = f"{name}({model.version})" return name + benchmarkJob = cls.get_job(job_id) scores_df = benchmarkJob.get_scores() scores_df["Model"] = scores_df["Model"].apply(lambda x: __get_model_name(x)) - return scores_df \ No newline at end of file + return scores_df diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index 1f81ac4d..59333ffd 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -21,7 +21,6 @@ Corpus Factory Class """ -import aixplain.utils.config as config import aixplain.processes.data_onboarding.onboard_functions as onboard_functions import json import logging @@ -86,12 +85,12 @@ def __from_response(cls, response: Dict) -> Corpus: try: license = License(response["license"]["typeId"]) - except: + except Exception: license = None try: length = int(response["segmentsCount"]) - except: + except Exception: length = None corpus = Corpus( @@ -116,17 +115,26 @@ def get(cls, corpus_id: Text) -> Corpus: Returns: Corpus: Created 'Corpus' object """ - url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - logging.info(f"Start service for GET Corpus - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - if "statusCode" in resp and resp["statusCode"] == 404: - raise Exception(f"Corpus GET Error: Dataset {corpus_id} not found.") - return cls.__from_response(resp) + try: + url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + logging.info(f"Start service for GET Corpus - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + if 200 <= r.status_code < 300: + logging.info(f"Corpus {corpus_id} retrieved successfully.") + return cls.__from_response(resp) + else: + error_message = f"Corpus GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) + except Exception as e: + error_message = f"Error retrieving Corpus {corpus_id}: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, corpus_id: Text) -> Corpus: @@ -162,52 +170,63 @@ def list( Returns: Dict: list of corpora in agreement with the filters, page number, page total and total elements """ - url = urljoin(cls.backend_url, "sdk/corpora/paginate") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - - assert 0 < page_size <= 100, f"Corpus List Error: Page size must be greater than 0 and not exceed 100." - payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]} - - if query is not None: - payload["q"] = str(query) - - if function is not None: - payload["function"] = function.value - - if license is not None: - payload["license"] = license.value - - if data_type is not None: - payload["dataType"] = data_type.value + try: + url = urljoin(cls.backend_url, "sdk/corpora/paginate") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + assert 0 < page_size <= 100, "Corpus List Error: Page size must be greater than 0 and not exceed 100." + payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]} + + if query is not None: + payload["q"] = str(query) + + if function is not None: + payload["function"] = function.value + + if license is not None: + payload["license"] = license.value + + if data_type is not None: + payload["dataType"] = data_type.value + + if language is not None: + if isinstance(language, Language): + language = [language] + payload["language"] = [lng.value["language"] for lng in language] + + logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}") + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + if 200 <= r.status_code < 300: + corpora, page_total, total = [], 0, 0 + if "results" in resp: + results = resp["results"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}") + for corpus in results: + corpus_ = cls.__from_response(corpus) + # add languages + languages = [] + for lng in corpus["languages"]: + if "dialect" not in lng: + lng["dialect"] = "" + languages.append(Language(lng)) + corpus_.kwargs["languages"] = languages + corpora.append(corpus_) + return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} + else: + error_message = f"Corpus List Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) - if language is not None: - if isinstance(language, Language): - language = [language] - payload["language"] = [lng.value["language"] for lng in language] - - logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}") - r = _request_with_retry("post", url, headers=headers, json=payload) - resp = r.json() - corpora, page_total, total = [], 0, 0 - if "results" in resp: - results = resp["results"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}") - for corpus in results: - corpus_ = cls.__from_response(corpus) - # add languages - languages = [] - for lng in corpus["languages"]: - if "dialect" not in lng: - lng["dialect"] = "" - languages.append(Language(lng)) - corpus_.kwargs["languages"] = languages - corpora.append(corpus_) - return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} + except Exception as e: + error_message = f"Error listing corpora: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) @classmethod def get_assets_from_page( @@ -245,7 +264,7 @@ def create( functions: List[Function] = [], privacy: Privacy = Privacy.PRIVATE, error_handler: ErrorHandler = ErrorHandler.SKIP, - api_key: Optional[Text] = None + api_key: Optional[Text] = None, ) -> Dict: """Asynchronous call to Upload a corpus to the user's dashboard. diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index 5e69d572..4b486cf0 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -21,7 +21,6 @@ Dataset Factory Class """ -import aixplain.utils.config as config import aixplain.processes.data_onboarding.onboard_functions as onboard_functions import json import os @@ -49,7 +48,6 @@ from typing import Any, Dict, List, Optional, Text, Union from urllib.parse import urljoin from uuid import uuid4 -from warnings import warn class DatasetFactory(AssetFactory): @@ -122,7 +120,7 @@ def __from_response(cls, response: Dict) -> Dataset: target_data_list = [data[data_id] for data_id in out["dataIds"]] data_name = target_data_list[0].name target_data[data_name] = target_data_list - except: + except Exception: pass # process function @@ -164,17 +162,26 @@ def get(cls, dataset_id: Text) -> Dataset: Returns: Dataset: Created 'Dataset' object """ - url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - logging.info(f"Start service for GET Dataset - {url} - {headers}") - r = _request_with_retry("get", url, headers=headers) - resp = r.json() - if "statusCode" in resp and resp["statusCode"] == 404: - raise Exception(f"Dataset GET Error: Dataset {dataset_id} not found.") - return cls.__from_response(resp) + try: + url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + logging.info(f"Start service for GET Dataset - {url} - {headers}") + r = _request_with_retry("get", url, headers=headers) + resp = r.json() + if 200 <= r.status_code < 300: + logging.info(f"Dataset {dataset_id} retrieved successfully.") + return cls.__from_response(resp) + else: + error_message = f"Dataset GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) + except Exception as e: + error_message = f"Error retrieving Dataset {dataset_id}: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) @classmethod def list( @@ -205,59 +212,70 @@ def list( Returns: Dict: list of datasets in agreement with the filters, page number, page total and total elements """ - url = urljoin(cls.backend_url, "sdk/datasets/paginate") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - - assert 0 < page_size <= 100, f"Dataset List Error: Page size must be greater than 0 and not exceed 100." - payload = { - "pageSize": page_size, - "pageNumber": page_number, - "sort": [{"field": "createdAt", "dir": -1}], - "input": {}, - "output": {}, - } - - if query is not None: - payload["q"] = str(query) - - if function is not None: - payload["function"] = function.value - - if license is not None: - payload["license"] = license.value - - if data_type is not None: - payload["dataType"] = data_type.value - - if is_referenceless is not None: - payload["isReferenceless"] = is_referenceless - - if source_languages is not None: - if isinstance(source_languages, Language): - source_languages = [source_languages] - payload["input"]["languages"] = [lng.value["language"] for lng in source_languages] - - if target_languages is not None: - if isinstance(target_languages, Language): - target_languages = [target_languages] - payload["output"]["languages"] = [lng.value["language"] for lng in target_languages] - - logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}") - r = _request_with_retry("post", url, headers=headers, json=payload) - resp = r.json() - - datasets, page_total, total = [], 0, 0 - if "results" in resp: - results = resp["results"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}") - for dataset in results: - datasets.append(cls.__from_response(dataset)) - return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} + try: + url = urljoin(cls.backend_url, "sdk/datasets/paginate") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + assert 0 < page_size <= 100, "Dataset List Error: Page size must be greater than 0 and not exceed 100." + payload = { + "pageSize": page_size, + "pageNumber": page_number, + "sort": [{"field": "createdAt", "dir": -1}], + "input": {}, + "output": {}, + } + + if query is not None: + payload["q"] = str(query) + + if function is not None: + payload["function"] = function.value + + if license is not None: + payload["license"] = license.value + + if data_type is not None: + payload["dataType"] = data_type.value + + if is_referenceless is not None: + payload["isReferenceless"] = is_referenceless + + if source_languages is not None: + if isinstance(source_languages, Language): + source_languages = [source_languages] + payload["input"]["languages"] = [lng.value["language"] for lng in source_languages] + + if target_languages is not None: + if isinstance(target_languages, Language): + target_languages = [target_languages] + payload["output"]["languages"] = [lng.value["language"] for lng in target_languages] + + logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}") + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + + if 200 <= r.status_code < 300: + datasets, page_total, total = [], 0, 0 + if "results" in resp: + results = resp["results"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}") + for dataset in results: + datasets.append(cls.__from_response(dataset)) + return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} + else: + error_message = f"Dataset List Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) + + except Exception as e: + error_message = f"Error listing datasets: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) @classmethod def create( @@ -282,7 +300,7 @@ def create( error_handler: ErrorHandler = ErrorHandler.SKIP, s3_link: Optional[Text] = None, aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None}, - api_key: Optional[Text] = None + api_key: Optional[Text] = None, ) -> Dict: """Dataset Onboard diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index d82bdd63..b7b7ee42 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -113,13 +113,19 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: logging.info(f"Start service for GET Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - # set api key - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - model = cls._create_model_from_response(resp) - logging.info(f"Model Creation: Model {model_id} instantiated.") - return model + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + model = cls._create_model_from_response(resp) + logging.info(f"Model Creation: Model {model_id} instantiated.") + return model + else: + error_message = ( + f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) except Exception: if resp is not None and "statusCode" in resp: status_code = resp["statusCode"] @@ -198,10 +204,15 @@ def _get_assets_from_page( logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") r = _request_with_retry("post", url, headers=headers, json=filter_params) resp = r.json() - logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") - all_models = resp["items"] - model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] - return model_list, resp["total"] + if 200 <= r.status_code < 300: + logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") + all_models = resp["items"] + model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] + return model_list, resp["total"] + else: + error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) except Exception as e: error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" logging.error(error_message, exc_info=True) diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index cb4336fe..ba8ccad9 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -78,12 +78,19 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: logging.info(f"Start service for GET Pipeline - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - # set api key - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - pipeline = build_from_response(resp, load_architecture=True) - return pipeline + + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + pipeline = build_from_response(resp, load_architecture=True) + logging.info(f"Pipeline {pipeline_id} retrieved successfully.") + return pipeline + + else: + error_message = f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) except Exception as e: logging.exception(e) status_code = 400 @@ -220,23 +227,34 @@ def list( payload["inputDataTypes"] = [data_type.value for data_type in output_data_types] logging.info(f"Start service for POST List Pipeline - {url} - {headers} - {json.dumps(payload)}") - r = _request_with_retry("post", url, headers=headers, json=payload) - resp = r.json() - - pipelines, page_total, total = [], 0, 0 - if "items" in resp: - results = resp["items"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") - for pipeline in results: - pipelines.append(build_from_response(pipeline)) - return { - "results": pipelines, - "page_total": page_total, - "page_number": page_number, - "total": total, - } + try: + r = _request_with_retry("post", url, headers=headers, json=payload) + resp = r.json() + if 200 <= r.status_code < 300: + pipelines, page_total, total = [], 0, 0 + if "items" in resp: + results = resp["items"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") + for pipeline in results: + pipelines.append(build_from_response(pipeline)) + return { + "results": pipelines, + "page_total": page_total, + "page_number": page_number, + "total": total, + } + else: + error_message = ( + f"Pipeline List Error: Failed to retrieve pipelines. Status Code: {r.status_code}. Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) + except Exception as e: + error_message = f"Pipeline List Error: {str(e)}" + logging.error(error_message, exc_info=True) + raise Exception(error_message) @classmethod def init(cls, name: Text, api_key: Optional[Text] = None) -> Pipeline: From 30a86c46dad2e601dcee0a8ac60810474b4c767e Mon Sep 17 00:00:00 2001 From: xainaz Date: Fri, 11 Oct 2024 18:55:55 +0300 Subject: [PATCH 2/4] Fixed issue + Added test --- aixplain/factories/benchmark_factory.py | 56 ++++++++------- aixplain/factories/corpus_factory.py | 59 ++++++++-------- aixplain/factories/dataset_factory.py | 44 ++++++------ aixplain/factories/model_factory.py | 44 ++++++------ .../factories/pipeline_factory/__init__.py | 67 +++++++++--------- tests/unit/benchmark_test.py | 70 +++++++++++++++++++ tests/unit/corpus_test.py | 34 +++++++++ tests/unit/dataset_test.py | 34 +++++++++ tests/unit/model_test.py | 42 +++++++++++ tests/unit/pipeline_test.py | 50 ++++++++++--- 10 files changed, 359 insertions(+), 141 deletions(-) create mode 100644 tests/unit/benchmark_test.py create mode 100644 tests/unit/corpus_test.py create mode 100644 tests/unit/dataset_test.py diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index ea983075..baf7ccb5 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -114,14 +114,7 @@ def get(cls, benchmark_id: str) -> Benchmark: logging.info(f"Start service for GET Benchmark - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - if 200 <= r.status_code < 300: - benchmark = cls._create_benchmark_from_response(resp) - logging.info(f"Benchmark {benchmark_id} retrieved successfully.") - return benchmark - else: - error_message = f"Benchmark GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" - logging.error(error_message) - raise Exception(error_message) + except Exception as e: status_code = 400 if resp is not None and "statusCode" in resp: @@ -132,6 +125,14 @@ def get(cls, benchmark_id: str) -> Benchmark: message = "Benchmark Creation: Unspecified Error" logging.error(f"Benchmark Creation Failed: {e}") raise Exception(f"Status {status_code}: {message}") + if 200 <= r.status_code < 300: + benchmark = cls._create_benchmark_from_response(resp) + logging.info(f"Benchmark {benchmark_id} retrieved successfully.") + return benchmark + else: + error_message = f"Benchmark GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def get_job(cls, job_id: Text) -> BenchmarkJob: @@ -207,17 +208,19 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], payload = json.dumps(clean_payload) r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() - if 200 <= r.status_code < 300: - logging.info(f"Benchmark {name} created successfully: {resp}") - return cls.get(resp["id"]) - else: - error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp.get('message', 'No message')}" - logging.error(error_message) - raise Exception(error_message) + except Exception as e: error_message = f"Creating Benchmark Job: Error in Creating Benchmark with payload {payload} : {e}" logging.error(error_message, exc_info=True) - return None + raise Exception(error_message) + + if 200 <= r.status_code < 300: + logging.info(f"Benchmark {name} created successfully: {resp}") + return cls.get(resp["id"]) + else: + error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: @@ -240,20 +243,19 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: r = _request_with_retry("post", url, headers=headers, data=payload) resp = r.json() - if 200 <= r.status_code < 300: - logging.info(f"Listing Normalization Options: Status of listing options: {resp}") - normalization_options = [item["value"] for item in resp] - return normalization_options - else: - error_message = ( - f"Error listing normalization options: Status {r.status_code} - {resp.get('message', 'No message')}" - ) - logging.error(error_message) - return [] except Exception as e: error_message = f"Listing Normalization Options: Error in getting Normalization Options: {e}" logging.error(error_message, exc_info=True) - return [] + raise Exception(error_message) + + if 200 <= r.status_code < 300: + logging.info(f"Listing Normalization Options: Status of listing options: {resp}") + normalization_options = [item["value"] for item in resp] + return normalization_options + else: + error_message = f"Error listing normalization options: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def get_benchmark_job_scores(cls, job_id): diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index 59333ffd..1f9f78b0 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -124,17 +124,18 @@ def get(cls, corpus_id: Text) -> Corpus: logging.info(f"Start service for GET Corpus - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - if 200 <= r.status_code < 300: - logging.info(f"Corpus {corpus_id} retrieved successfully.") - return cls.__from_response(resp) - else: - error_message = f"Corpus GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" - logging.error(error_message) - raise Exception(error_message) + except Exception as e: error_message = f"Error retrieving Corpus {corpus_id}: {str(e)}" logging.error(error_message, exc_info=True) raise Exception(error_message) + if 200 <= r.status_code < 300: + logging.info(f"Corpus {corpus_id} retrieved successfully.") + return cls.__from_response(resp) + else: + error_message = f"Corpus GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, corpus_id: Text) -> Corpus: @@ -200,33 +201,33 @@ def list( logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() - if 200 <= r.status_code < 300: - corpora, page_total, total = [], 0, 0 - if "results" in resp: - results = resp["results"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}") - for corpus in results: - corpus_ = cls.__from_response(corpus) - # add languages - languages = [] - for lng in corpus["languages"]: - if "dialect" not in lng: - lng["dialect"] = "" - languages.append(Language(lng)) - corpus_.kwargs["languages"] = languages - corpora.append(corpus_) - return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} - else: - error_message = f"Corpus List Error: Status {r.status_code} - {resp.get('message', 'No message')}" - logging.error(error_message) - raise Exception(error_message) except Exception as e: error_message = f"Error listing corpora: {str(e)}" logging.error(error_message, exc_info=True) raise Exception(error_message) + if 200 <= r.status_code < 300: + corpora, page_total, total = [], 0, 0 + if "results" in resp: + results = resp["results"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}") + for corpus in results: + corpus_ = cls.__from_response(corpus) + # add languages + languages = [] + for lng in corpus["languages"]: + if "dialect" not in lng: + lng["dialect"] = "" + languages.append(Language(lng)) + corpus_.kwargs["languages"] = languages + corpora.append(corpus_) + return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} + else: + error_message = f"Corpus List Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def get_assets_from_page( diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index 4b486cf0..67afecad 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -171,17 +171,18 @@ def get(cls, dataset_id: Text) -> Dataset: logging.info(f"Start service for GET Dataset - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - if 200 <= r.status_code < 300: - logging.info(f"Dataset {dataset_id} retrieved successfully.") - return cls.__from_response(resp) - else: - error_message = f"Dataset GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" - logging.error(error_message) - raise Exception(error_message) + except Exception as e: error_message = f"Error retrieving Dataset {dataset_id}: {str(e)}" logging.error(error_message, exc_info=True) raise Exception(error_message) + if 200 <= r.status_code < 300: + logging.info(f"Dataset {dataset_id} retrieved successfully.") + return cls.__from_response(resp) + else: + error_message = f"Dataset GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def list( @@ -257,25 +258,24 @@ def list( r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() - if 200 <= r.status_code < 300: - datasets, page_total, total = [], 0, 0 - if "results" in resp: - results = resp["results"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}") - for dataset in results: - datasets.append(cls.__from_response(dataset)) - return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} - else: - error_message = f"Dataset List Error: Status {r.status_code} - {resp.get('message', 'No message')}" - logging.error(error_message) - raise Exception(error_message) - except Exception as e: error_message = f"Error listing datasets: {str(e)}" logging.error(error_message, exc_info=True) raise Exception(error_message) + if 200 <= r.status_code < 300: + datasets, page_total, total = [], 0, 0 + if "results" in resp: + results = resp["results"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}") + for dataset in results: + datasets.append(cls.__from_response(dataset)) + return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} + else: + error_message = f"Dataset List Error: Status {r.status_code} - {resp.get('message', 'No message')}" + logging.error(error_message) + raise Exception(error_message) @classmethod def create( diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index b7b7ee42..8f544882 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -113,19 +113,7 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: logging.info(f"Start service for GET Model - {url} - {headers}") r = _request_with_retry("get", url, headers=headers) resp = r.json() - if 200 <= r.status_code < 300: - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - model = cls._create_model_from_response(resp) - logging.info(f"Model Creation: Model {model_id} instantiated.") - return model - else: - error_message = ( - f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" - ) - logging.error(error_message) - raise Exception(error_message) + except Exception: if resp is not None and "statusCode" in resp: status_code = resp["statusCode"] @@ -135,6 +123,17 @@ def get(cls, model_id: Text, api_key: Optional[Text] = None) -> Model: message = "Model Creation: Unspecified Error" logging.error(message) raise Exception(f"{message}") + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + model = cls._create_model_from_response(resp) + logging.info(f"Model Creation: Model {model_id} instantiated.") + return model + else: + error_message = f"Model GET Error: Failed to retrieve model {model_id}. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, model_id: Text) -> Model: @@ -204,19 +203,20 @@ def _get_assets_from_page( logging.info(f"Start service for POST Models Paginate - {url} - {headers} - {json.dumps(filter_params)}") r = _request_with_retry("post", url, headers=headers, json=filter_params) resp = r.json() - if 200 <= r.status_code < 300: - logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") - all_models = resp["items"] - model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] - return model_list, resp["total"] - else: - error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" - logging.error(error_message) - raise Exception(error_message) + except Exception as e: error_message = f"Listing Models: Error in getting Models on Page {page_number}: {e}" logging.error(error_message, exc_info=True) return [] + if 200 <= r.status_code < 300: + logging.info(f"Listing Models: Status of getting Models on Page {page_number}: {r.status_code}") + all_models = resp["items"] + model_list = [cls._create_model_from_response(model_info_json) for model_info_json in all_models] + return model_list, resp["total"] + else: + error_message = f"Listing Models Error: Failed to retrieve models. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def list( diff --git a/aixplain/factories/pipeline_factory/__init__.py b/aixplain/factories/pipeline_factory/__init__.py index ba8ccad9..ef330de0 100644 --- a/aixplain/factories/pipeline_factory/__init__.py +++ b/aixplain/factories/pipeline_factory/__init__.py @@ -79,18 +79,6 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: r = _request_with_retry("get", url, headers=headers) resp = r.json() - if 200 <= r.status_code < 300: - resp["api_key"] = config.TEAM_API_KEY - if api_key is not None: - resp["api_key"] = api_key - pipeline = build_from_response(resp, load_architecture=True) - logging.info(f"Pipeline {pipeline_id} retrieved successfully.") - return pipeline - - else: - error_message = f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" - logging.error(error_message) - raise Exception(error_message) except Exception as e: logging.exception(e) status_code = 400 @@ -102,6 +90,20 @@ def get(cls, pipeline_id: Text, api_key: Optional[Text] = None) -> Pipeline: message = f"Pipeline Creation: Unspecified Error {e}" logging.error(message) raise Exception(f"Status {status_code}: {message}") + if 200 <= r.status_code < 300: + resp["api_key"] = config.TEAM_API_KEY + if api_key is not None: + resp["api_key"] = api_key + pipeline = build_from_response(resp, load_architecture=True) + logging.info(f"Pipeline {pipeline_id} retrieved successfully.") + return pipeline + + else: + error_message = ( + f"Pipeline GET Error: Failed to retrieve pipeline {pipeline_id}. Status Code: {r.status_code}. Error: {resp}" + ) + logging.error(error_message) + raise Exception(error_message) @classmethod def create_asset_from_id(cls, pipeline_id: Text) -> Pipeline: @@ -230,31 +232,30 @@ def list( try: r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() - if 200 <= r.status_code < 300: - pipelines, page_total, total = [], 0, 0 - if "items" in resp: - results = resp["items"] - page_total = resp["pageTotal"] - total = resp["total"] - logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") - for pipeline in results: - pipelines.append(build_from_response(pipeline)) - return { - "results": pipelines, - "page_total": page_total, - "page_number": page_number, - "total": total, - } - else: - error_message = ( - f"Pipeline List Error: Failed to retrieve pipelines. Status Code: {r.status_code}. Error: {resp}" - ) - logging.error(error_message) - raise Exception(error_message) + except Exception as e: error_message = f"Pipeline List Error: {str(e)}" logging.error(error_message, exc_info=True) raise Exception(error_message) + if 200 <= r.status_code < 300: + pipelines, page_total, total = [], 0, 0 + if "items" in resp: + results = resp["items"] + page_total = resp["pageTotal"] + total = resp["total"] + logging.info(f"Response for POST List Pipeline - Page Total: {page_total} / Total: {total}") + for pipeline in results: + pipelines.append(build_from_response(pipeline)) + return { + "results": pipelines, + "page_total": page_total, + "page_number": page_number, + "total": total, + } + else: + error_message = f"Pipeline List Error: Failed to retrieve pipelines. Status Code: {r.status_code}. Error: {resp}" + logging.error(error_message) + raise Exception(error_message) @classmethod def init(cls, name: Text, api_key: Optional[Text] = None) -> Pipeline: diff --git a/tests/unit/benchmark_test.py b/tests/unit/benchmark_test.py new file mode 100644 index 00000000..9dd18f53 --- /dev/null +++ b/tests/unit/benchmark_test.py @@ -0,0 +1,70 @@ +import requests_mock +import pytest +from urllib.parse import urljoin +from aixplain.utils import config +from aixplain.factories import MetricFactory, BenchmarkFactory +from aixplain.modules.model import Model +from aixplain.modules.dataset import Dataset + + +def test_create_benchmark_error_response(): + metric_list = [MetricFactory.get("66df3e2d6eb56336b6628171")] + with requests_mock.Mocker() as mock: + name = "test-benchmark" + dataset_list = [ + Dataset( + id="dataset1", + name="Dataset 1", + description="Test dataset", + function="test_func", + source_data="src", + target_data="tgt", + onboard_status="onboarded", + ) + ] + model_list = [ + Model(id="model1", name="Model 1", description="Test model", supplier="Test supplier", cost=10, version="v1") + ] + + url = urljoin(config.BACKEND_URL, "sdk/benchmarks") + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + error_response = {"statusCode": 400, "message": "Invalid request"} + mock.post(url, headers=headers, json=error_response, status_code=400) + + with pytest.raises(Exception) as excinfo: + BenchmarkFactory.create(name=name, dataset_list=dataset_list, model_list=model_list, metric_list=metric_list) + + assert "Benchmark Creation Error: Status 400 - Invalid request" in str(excinfo.value) + + +def test_get_benchmark_error(): + with requests_mock.Mocker() as mock: + benchmark_id = "test-benchmark-id" + url = urljoin(config.BACKEND_URL, f"sdk/benchmarks/{benchmark_id}") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 404, "message": "Benchmark not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + BenchmarkFactory.get(benchmark_id) + + assert "Benchmark GET Error: Status 404 - Benchmark not found" in str(excinfo.value) + + +def test_list_normalization_options_error(): + metric = MetricFactory.get("66df3e2d6eb56336b6628171") + with requests_mock.Mocker() as mock: + model = Model(id="model1", name="Test Model", description="Test model", supplier="Test supplier", cost=10, version="v1") + + url = urljoin(config.BACKEND_URL, "sdk/benchmarks/normalization-options") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + BenchmarkFactory.list_normalization_options(metric, model) + + assert "Error listing normalization options: Status 500 - Internal Server Error" in str(excinfo.value) diff --git a/tests/unit/corpus_test.py b/tests/unit/corpus_test.py new file mode 100644 index 00000000..9d2ef254 --- /dev/null +++ b/tests/unit/corpus_test.py @@ -0,0 +1,34 @@ +from aixplain.factories import CorpusFactory +import pytest +import requests_mock +from urllib.parse import urljoin +from aixplain.utils import config + + +def test_get_corpus_error_response(): + with requests_mock.Mocker() as mock: + corpus_id = "invalid_corpus_id" + url = urljoin(config.BACKEND_URL, f"sdk/corpora/{corpus_id}/overview") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Not Found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + CorpusFactory.get(corpus_id=corpus_id) + + assert "Corpus GET Error: Status 404 - Not Found" in str(excinfo.value) + + +def test_list_corpus_error_response(): + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/corpora/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + CorpusFactory.list(query="test_query", page_number=0, page_size=20) + + assert "Corpus List Error: Status 500 - Internal Server Error" in str(excinfo.value) diff --git a/tests/unit/dataset_test.py b/tests/unit/dataset_test.py new file mode 100644 index 00000000..4332334c --- /dev/null +++ b/tests/unit/dataset_test.py @@ -0,0 +1,34 @@ +import pytest +import requests_mock +from aixplain.factories import DatasetFactory +from urllib.parse import urljoin +from aixplain.utils import config + + +def test_list_dataset_error_response(): + with requests_mock.Mocker() as mock: + url = urljoin(config.BACKEND_URL, "sdk/datasets/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + DatasetFactory.list(query="test_query", page_number=0, page_size=20) + + assert "Dataset List Error: Status 500 - Internal Server Error" in str(excinfo.value) + + +def test_get_dataset_error_response(): + with requests_mock.Mocker() as mock: + dataset_id = "invalid_dataset_id" + url = urljoin(config.BACKEND_URL, f"sdk/datasets/{dataset_id}/overview") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"message": "Not Found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + DatasetFactory.get(dataset_id=dataset_id) + + assert "Dataset GET Error: Status 404 - Not Found" in str(excinfo.value) diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index c52bb950..a319742c 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -24,6 +24,8 @@ import re from aixplain.utils import config from aixplain.modules import Model +from aixplain.factories import ModelFactory +from aixplain.enums import Function import pytest @@ -82,3 +84,43 @@ def test_run_async_errors(status_code, error_message): response = test_model.run_async(data="input_data") assert response["status"] == "FAILED" assert response["error_message"] == error_message + + +def test_get_model_error_response(): + with requests_mock.Mocker() as mock: + model_id = "test-model-id" + url = urljoin(config.BACKEND_URL, f"sdk/models/{model_id}") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 404, "message": "Model not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + ModelFactory.get(model_id) + + assert "Model GET Error: Failed to retrieve model test-model-id" in str(excinfo.value) + + +def test_get_assets_from_page_error(): + with requests_mock.Mocker() as mock: + query = "test-query" + page_number = 0 + page_size = 2 + url = urljoin(config.BACKEND_URL, "sdk/models/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 500, "message": "Internal Server Error"} + mock.post(url, headers=headers, json=error_response, status_code=500) + + with pytest.raises(Exception) as excinfo: + ModelFactory._get_assets_from_page( + query=query, + page_number=page_number, + page_size=page_size, + function=Function.TEXT_GENERATION, + suppliers=None, + source_languages=None, + target_languages=None, + ) + + assert "Listing Models Error: Failed to retrieve models" in str(excinfo.value) diff --git a/tests/unit/pipeline_test.py b/tests/unit/pipeline_test.py index d3c1c725..05ee7172 100644 --- a/tests/unit/pipeline_test.py +++ b/tests/unit/pipeline_test.py @@ -38,27 +38,61 @@ def test_create_pipeline(): assert hyp_pipeline.id == ref_pipeline.id assert hyp_pipeline.name == ref_pipeline.name + @pytest.mark.parametrize( "status_code,error_message", [ - (401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."), - (465,"Subscription-related error: Please ensure that your subscription is active and has not expired."), - (475,"Billing-related error: Please ensure you have enough credits to run this pipeline. "), - (485, "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access."), + (401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."), + (465, "Subscription-related error: Please ensure that your subscription is active and has not expired."), + (475, "Billing-related error: Please ensure you have enough credits to run this pipeline. "), + ( + 485, + "Supplier-related error: Please ensure that the selected supplier provides the pipeline you are trying to access.", + ), (495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."), (501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."), - ], ) - def test_run_async_errors(status_code, error_message): base_url = config.BACKEND_URL pipeline_id = "pipeline_id" execute_url = f"{base_url}/assets/pipeline/execution/run/{pipeline_id}" - + with requests_mock.Mocker() as mock: mock.post(execute_url, status_code=status_code) test_pipeline = Pipeline(id=pipeline_id, api_key=config.TEAM_API_KEY, name="Test Pipeline", url=base_url) response = test_pipeline.run_async(data="input_data") assert response["status"] == "FAILED" - assert response["error_message"] == error_message \ No newline at end of file + assert response["error_message"] == error_message + + +def test_list_pipelines_error_response(): + with requests_mock.Mocker() as mock: + query = "test-query" + page_number = 0 + page_size = 20 + url = urljoin(config.BACKEND_URL, "sdk/pipelines/paginate") + headers = {"x-aixplain-key": config.AIXPLAIN_API_KEY, "Content-Type": "application/json"} + + error_response = {"statusCode": 400, "message": "Bad Request"} + mock.post(url, headers=headers, json=error_response, status_code=400) + + with pytest.raises(Exception) as excinfo: + PipelineFactory.list(query=query, page_number=page_number, page_size=page_size) + + assert "Pipeline List Error: Failed to retrieve pipelines. Status Code: 400" in str(excinfo.value) + + +def test_get_pipeline_error_response(): + with requests_mock.Mocker() as mock: + pipeline_id = "test-pipeline-id" + url = urljoin(config.BACKEND_URL, f"sdk/pipelines/{pipeline_id}") + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + + error_response = {"statusCode": 404, "message": "Pipeline not found"} + mock.get(url, headers=headers, json=error_response, status_code=404) + + with pytest.raises(Exception) as excinfo: + PipelineFactory.get(pipeline_id=pipeline_id) + + assert "Pipeline GET Error: Failed to retrieve pipeline test-pipeline-id. Status Code: 404" in str(excinfo.value) From c0585478a7b88e460bc88696af1b631ccb0b88dd Mon Sep 17 00:00:00 2001 From: xainaz Date: Mon, 14 Oct 2024 22:05:39 +0300 Subject: [PATCH 3/4] Added required changes to error prompts --- aixplain/factories/benchmark_factory.py | 10 +++++----- aixplain/factories/corpus_factory.py | 4 ++-- aixplain/factories/dataset_factory.py | 4 ++-- tests/unit/benchmark_test.py | 6 +++--- tests/unit/corpus_test.py | 4 ++-- tests/unit/dataset_test.py | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/aixplain/factories/benchmark_factory.py b/aixplain/factories/benchmark_factory.py index baf7ccb5..305fb5d9 100644 --- a/aixplain/factories/benchmark_factory.py +++ b/aixplain/factories/benchmark_factory.py @@ -130,7 +130,7 @@ def get(cls, benchmark_id: str) -> Benchmark: logging.info(f"Benchmark {benchmark_id} retrieved successfully.") return benchmark else: - error_message = f"Benchmark GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Benchmark GET Error: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) @@ -215,10 +215,10 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model], raise Exception(error_message) if 200 <= r.status_code < 300: - logging.info(f"Benchmark {name} created successfully: {resp}") + logging.info(f"Benchmark {name} created successfully.") return cls.get(resp["id"]) else: - error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) @@ -249,11 +249,11 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]: raise Exception(error_message) if 200 <= r.status_code < 300: - logging.info(f"Listing Normalization Options: Status of listing options: {resp}") + logging.info("Listing Normalization Options: ") normalization_options = [item["value"] for item in resp] return normalization_options else: - error_message = f"Error listing normalization options: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Error listing normalization options: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index 1f9f78b0..c971c280 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -133,7 +133,7 @@ def get(cls, corpus_id: Text) -> Corpus: logging.info(f"Corpus {corpus_id} retrieved successfully.") return cls.__from_response(resp) else: - error_message = f"Corpus GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Corpus GET Error: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) @@ -225,7 +225,7 @@ def list( corpora.append(corpus_) return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total} else: - error_message = f"Corpus List Error: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Corpus List Error: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index 67afecad..1d2da882 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -180,7 +180,7 @@ def get(cls, dataset_id: Text) -> Dataset: logging.info(f"Dataset {dataset_id} retrieved successfully.") return cls.__from_response(resp) else: - error_message = f"Dataset GET Error: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Dataset GET Error: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) @@ -273,7 +273,7 @@ def list( datasets.append(cls.__from_response(dataset)) return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total} else: - error_message = f"Dataset List Error: Status {r.status_code} - {resp.get('message', 'No message')}" + error_message = f"Dataset List Error: Status {r.status_code} - {resp}" logging.error(error_message) raise Exception(error_message) diff --git a/tests/unit/benchmark_test.py b/tests/unit/benchmark_test.py index 9dd18f53..167e4bcb 100644 --- a/tests/unit/benchmark_test.py +++ b/tests/unit/benchmark_test.py @@ -35,7 +35,7 @@ def test_create_benchmark_error_response(): with pytest.raises(Exception) as excinfo: BenchmarkFactory.create(name=name, dataset_list=dataset_list, model_list=model_list, metric_list=metric_list) - assert "Benchmark Creation Error: Status 400 - Invalid request" in str(excinfo.value) + assert "Benchmark Creation Error: Status 400 - {'statusCode': 400, 'message': 'Invalid request'}" in str(excinfo.value) def test_get_benchmark_error(): @@ -50,7 +50,7 @@ def test_get_benchmark_error(): with pytest.raises(Exception) as excinfo: BenchmarkFactory.get(benchmark_id) - assert "Benchmark GET Error: Status 404 - Benchmark not found" in str(excinfo.value) + assert "Benchmark GET Error: Status 404 - {'statusCode': 404, 'message': 'Benchmark not found'}" in str(excinfo.value) def test_list_normalization_options_error(): @@ -67,4 +67,4 @@ def test_list_normalization_options_error(): with pytest.raises(Exception) as excinfo: BenchmarkFactory.list_normalization_options(metric, model) - assert "Error listing normalization options: Status 500 - Internal Server Error" in str(excinfo.value) + assert "Error listing normalization options: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) diff --git a/tests/unit/corpus_test.py b/tests/unit/corpus_test.py index 9d2ef254..07522c4d 100644 --- a/tests/unit/corpus_test.py +++ b/tests/unit/corpus_test.py @@ -17,7 +17,7 @@ def test_get_corpus_error_response(): with pytest.raises(Exception) as excinfo: CorpusFactory.get(corpus_id=corpus_id) - assert "Corpus GET Error: Status 404 - Not Found" in str(excinfo.value) + assert "Corpus GET Error: Status 404 - {'message': 'Not Found'}" in str(excinfo.value) def test_list_corpus_error_response(): @@ -31,4 +31,4 @@ def test_list_corpus_error_response(): with pytest.raises(Exception) as excinfo: CorpusFactory.list(query="test_query", page_number=0, page_size=20) - assert "Corpus List Error: Status 500 - Internal Server Error" in str(excinfo.value) + assert "Corpus List Error: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) diff --git a/tests/unit/dataset_test.py b/tests/unit/dataset_test.py index 4332334c..25c57123 100644 --- a/tests/unit/dataset_test.py +++ b/tests/unit/dataset_test.py @@ -16,7 +16,7 @@ def test_list_dataset_error_response(): with pytest.raises(Exception) as excinfo: DatasetFactory.list(query="test_query", page_number=0, page_size=20) - assert "Dataset List Error: Status 500 - Internal Server Error" in str(excinfo.value) + assert "Dataset List Error: Status 500 - {'message': 'Internal Server Error'}" in str(excinfo.value) def test_get_dataset_error_response(): @@ -31,4 +31,4 @@ def test_get_dataset_error_response(): with pytest.raises(Exception) as excinfo: DatasetFactory.get(dataset_id=dataset_id) - assert "Dataset GET Error: Status 404 - Not Found" in str(excinfo.value) + assert "Dataset GET Error: Status 404 - {'message': 'Not Found'}" in str(excinfo.value) From 6a720f3b6ca0584a44096aa145a622ace810f302 Mon Sep 17 00:00:00 2001 From: Thiago Castro Ferreira Date: Tue, 15 Oct 2024 17:56:40 -0300 Subject: [PATCH 4/4] Small improvements --- aixplain/factories/corpus_factory.py | 41 ++++++++--------- aixplain/factories/dataset_factory.py | 64 +++++++++++++-------------- aixplain/factories/model_factory.py | 43 ++++++++---------- 3 files changed, 72 insertions(+), 76 deletions(-) diff --git a/aixplain/factories/corpus_factory.py b/aixplain/factories/corpus_factory.py index c971c280..3b9c5e4b 100644 --- a/aixplain/factories/corpus_factory.py +++ b/aixplain/factories/corpus_factory.py @@ -171,33 +171,33 @@ def list( Returns: Dict: list of corpora in agreement with the filters, page number, page total and total elements """ - try: - url = urljoin(cls.backend_url, "sdk/corpora/paginate") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + url = urljoin(cls.backend_url, "sdk/corpora/paginate") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - assert 0 < page_size <= 100, "Corpus List Error: Page size must be greater than 0 and not exceed 100." - payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]} + assert 0 < page_size <= 100, "Corpus List Error: Page size must be greater than 0 and not exceed 100." + payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]} - if query is not None: - payload["q"] = str(query) + if query is not None: + payload["q"] = str(query) - if function is not None: - payload["function"] = function.value + if function is not None: + payload["function"] = function.value - if license is not None: - payload["license"] = license.value + if license is not None: + payload["license"] = license.value - if data_type is not None: - payload["dataType"] = data_type.value + if data_type is not None: + payload["dataType"] = data_type.value - if language is not None: - if isinstance(language, Language): - language = [language] - payload["language"] = [lng.value["language"] for lng in language] + if language is not None: + if isinstance(language, Language): + language = [language] + payload["language"] = [lng.value["language"] for lng in language] + try: logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() @@ -206,6 +206,7 @@ def list( error_message = f"Error listing corpora: {str(e)}" logging.error(error_message, exc_info=True) raise Exception(error_message) + if 200 <= r.status_code < 300: corpora, page_total, total = [], 0, 0 if "results" in resp: diff --git a/aixplain/factories/dataset_factory.py b/aixplain/factories/dataset_factory.py index 1d2da882..081513c0 100644 --- a/aixplain/factories/dataset_factory.py +++ b/aixplain/factories/dataset_factory.py @@ -213,47 +213,47 @@ def list( Returns: Dict: list of datasets in agreement with the filters, page number, page total and total elements """ - try: - url = urljoin(cls.backend_url, "sdk/datasets/paginate") - if cls.aixplain_key != "": - headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} - else: - headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} + url = urljoin(cls.backend_url, "sdk/datasets/paginate") + if cls.aixplain_key != "": + headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"} + else: + headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"} - assert 0 < page_size <= 100, "Dataset List Error: Page size must be greater than 0 and not exceed 100." - payload = { - "pageSize": page_size, - "pageNumber": page_number, - "sort": [{"field": "createdAt", "dir": -1}], - "input": {}, - "output": {}, - } + assert 0 < page_size <= 100, "Dataset List Error: Page size must be greater than 0 and not exceed 100." + payload = { + "pageSize": page_size, + "pageNumber": page_number, + "sort": [{"field": "createdAt", "dir": -1}], + "input": {}, + "output": {}, + } - if query is not None: - payload["q"] = str(query) + if query is not None: + payload["q"] = str(query) - if function is not None: - payload["function"] = function.value + if function is not None: + payload["function"] = function.value - if license is not None: - payload["license"] = license.value + if license is not None: + payload["license"] = license.value - if data_type is not None: - payload["dataType"] = data_type.value + if data_type is not None: + payload["dataType"] = data_type.value - if is_referenceless is not None: - payload["isReferenceless"] = is_referenceless + if is_referenceless is not None: + payload["isReferenceless"] = is_referenceless - if source_languages is not None: - if isinstance(source_languages, Language): - source_languages = [source_languages] - payload["input"]["languages"] = [lng.value["language"] for lng in source_languages] + if source_languages is not None: + if isinstance(source_languages, Language): + source_languages = [source_languages] + payload["input"]["languages"] = [lng.value["language"] for lng in source_languages] - if target_languages is not None: - if isinstance(target_languages, Language): - target_languages = [target_languages] - payload["output"]["languages"] = [lng.value["language"] for lng in target_languages] + if target_languages is not None: + if isinstance(target_languages, Language): + target_languages = [target_languages] + payload["output"]["languages"] = [lng.value["language"] for lng in target_languages] + try: logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}") r = _request_with_retry("post", url, headers=headers, json=payload) resp = r.json() diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index 8f544882..5df7c924 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -248,30 +248,25 @@ def list( Returns: List[Model]: List of models based on given filters """ - try: - models, total = cls._get_assets_from_page( - query, - page_number, - page_size, - function, - suppliers, - source_languages, - target_languages, - is_finetunable, - ownership, - sort_by, - sort_order, - ) - return { - "results": models, - "page_total": min(page_size, len(models)), - "page_number": page_number, - "total": total, - } - except Exception as e: - error_message = f"Listing Models: Error in Listing Models : {e}" - logging.error(error_message, exc_info=True) - raise Exception(error_message) + models, total = cls._get_assets_from_page( + query, + page_number, + page_size, + function, + suppliers, + source_languages, + target_languages, + is_finetunable, + ownership, + sort_by, + sort_order, + ) + return { + "results": models, + "page_total": min(page_size, len(models)), + "page_number": page_number, + "total": total, + } @classmethod def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]: