Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions aixplain/factories/benchmark_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
"""

import logging
from typing import Dict, List, Optional, Text
from typing import Dict, List, Text
import json
import pandas as pd
from pathlib import Path
from aixplain.enums.supplier import Supplier
from aixplain.modules import Dataset, Metric, Model
from aixplain.modules.benchmark_job import BenchmarkJob
Expand All @@ -34,9 +32,8 @@
from aixplain.factories.dataset_factory import DatasetFactory
from aixplain.factories.model_factory import ModelFactory
from aixplain.utils import config
from aixplain.utils.file_utils import _request_with_retry, save_file
from aixplain.utils.file_utils import _request_with_retry
from urllib.parse import urljoin
from warnings import warn


class BenchmarkFactory:
Expand Down Expand Up @@ -117,18 +114,25 @@ def get(cls, benchmark_id: str) -> Benchmark:
logging.info(f"Start service for GET Benchmark - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
resp = r.json()
benchmark = cls._create_benchmark_from_response(resp)

except Exception as e:
status_code = 400
if resp is not None and "statusCode" in resp:
status_code = resp["statusCode"]
message = resp["message"]
message = f"Benchmark Creation: Status {status_code} - {message}"
else:
message = f"Benchmark Creation: Unspecified Error"
message = "Benchmark Creation: Unspecified Error"
logging.error(f"Benchmark Creation Failed: {e}")
raise Exception(f"Status {status_code}: {message}")
return benchmark
if 200 <= r.status_code < 300:
benchmark = cls._create_benchmark_from_response(resp)
logging.info(f"Benchmark {benchmark_id} retrieved successfully.")
return benchmark
else:
error_message = f"Benchmark GET Error: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def get_job(cls, job_id: Text) -> BenchmarkJob:
Expand Down Expand Up @@ -189,7 +193,7 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model],
"""
payload = {}
try:
url = urljoin(cls.backend_url, f"sdk/benchmarks")
url = urljoin(cls.backend_url, "sdk/benchmarks")
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
payload = {
"name": name,
Expand All @@ -204,12 +208,19 @@ def create(cls, name: str, dataset_list: List[Dataset], model_list: List[Model],
payload = json.dumps(clean_payload)
r = _request_with_retry("post", url, headers=headers, data=payload)
resp = r.json()
logging.info(f"Creating Benchmark Job: Status for {name}: {resp}")
return cls.get(resp["id"])

except Exception as e:
error_message = f"Creating Benchmark Job: Error in Creating Benchmark with payload {payload} : {e}"
logging.error(error_message, exc_info=True)
return None
raise Exception(error_message)

if 200 <= r.status_code < 300:
logging.info(f"Benchmark {name} created successfully.")
return cls.get(resp["id"])
else:
error_message = f"Benchmark Creation Error: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]:
Expand All @@ -223,21 +234,28 @@ def list_normalization_options(cls, metric: Metric, model: Model) -> List[str]:
List[str]: List of supported normalization options
"""
try:
url = urljoin(cls.backend_url, f"sdk/benchmarks/normalization-options")
url = urljoin(cls.backend_url, "sdk/benchmarks/normalization-options")
if cls.aixplain_key != "":
headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"}
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
payload = json.dumps({"metricId": metric.id, "modelIds": [model.id]})
r = _request_with_retry("post", url, headers=headers, data=payload)
resp = r.json()
logging.info(f"Listing Normalization Options: Status of listing options: {resp}")
normalization_options = [item["value"] for item in resp]
return normalization_options

except Exception as e:
error_message = f"Listing Normalization Options: Error in getting Normalization Options: {e}"
logging.error(error_message, exc_info=True)
return []
raise Exception(error_message)

if 200 <= r.status_code < 300:
logging.info("Listing Normalization Options: ")
normalization_options = [item["value"] for item in resp]
return normalization_options
else:
error_message = f"Error listing normalization options: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def get_benchmark_job_scores(cls, job_id):
Expand All @@ -255,7 +273,8 @@ def __get_model_name(model_id):
if model.version is not None:
name = f"{name}({model.version})"
return name

benchmarkJob = cls.get_job(job_id)
scores_df = benchmarkJob.get_scores()
scores_df["Model"] = scores_df["Model"].apply(lambda x: __get_model_name(x))
return scores_df
return scores_df
91 changes: 56 additions & 35 deletions aixplain/factories/corpus_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Corpus Factory Class
"""

import aixplain.utils.config as config
import aixplain.processes.data_onboarding.onboard_functions as onboard_functions
import json
import logging
Expand Down Expand Up @@ -86,12 +85,12 @@ def __from_response(cls, response: Dict) -> Corpus:

try:
license = License(response["license"]["typeId"])
except:
except Exception:
license = None

try:
length = int(response["segmentsCount"])
except:
except Exception:
length = None

corpus = Corpus(
Expand All @@ -116,17 +115,27 @@ def get(cls, corpus_id: Text) -> Corpus:
Returns:
Corpus: Created 'Corpus' object
"""
url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview")
if cls.aixplain_key != "":
headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"}
try:
url = urljoin(cls.backend_url, f"sdk/corpora/{corpus_id}/overview")
if cls.aixplain_key != "":
headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"}
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
logging.info(f"Start service for GET Corpus - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
resp = r.json()

except Exception as e:
error_message = f"Error retrieving Corpus {corpus_id}: {str(e)}"
logging.error(error_message, exc_info=True)
raise Exception(error_message)
if 200 <= r.status_code < 300:
logging.info(f"Corpus {corpus_id} retrieved successfully.")
return cls.__from_response(resp)
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
logging.info(f"Start service for GET Corpus - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
resp = r.json()
if "statusCode" in resp and resp["statusCode"] == 404:
raise Exception(f"Corpus GET Error: Dataset {corpus_id} not found.")
return cls.__from_response(resp)
error_message = f"Corpus GET Error: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def create_asset_from_id(cls, corpus_id: Text) -> Corpus:
Expand Down Expand Up @@ -168,7 +177,7 @@ def list(
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}

assert 0 < page_size <= 100, f"Corpus List Error: Page size must be greater than 0 and not exceed 100."
assert 0 < page_size <= 100, "Corpus List Error: Page size must be greater than 0 and not exceed 100."
payload = {"pageSize": page_size, "pageNumber": page_number, "sort": [{"field": "createdAt", "dir": -1}]}

if query is not None:
Expand All @@ -188,26 +197,38 @@ def list(
language = [language]
payload["language"] = [lng.value["language"] for lng in language]

logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
resp = r.json()
corpora, page_total, total = [], 0, 0
if "results" in resp:
results = resp["results"]
page_total = resp["pageTotal"]
total = resp["total"]
logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}")
for corpus in results:
corpus_ = cls.__from_response(corpus)
# add languages
languages = []
for lng in corpus["languages"]:
if "dialect" not in lng:
lng["dialect"] = ""
languages.append(Language(lng))
corpus_.kwargs["languages"] = languages
corpora.append(corpus_)
return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total}
try:
logging.info(f"Start service for POST List Corpus - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
resp = r.json()

except Exception as e:
error_message = f"Error listing corpora: {str(e)}"
logging.error(error_message, exc_info=True)
raise Exception(error_message)

if 200 <= r.status_code < 300:
corpora, page_total, total = [], 0, 0
if "results" in resp:
results = resp["results"]
page_total = resp["pageTotal"]
total = resp["total"]
logging.info(f"Response for POST List Corpus - Page Total: {page_total} / Total: {total}")
for corpus in results:
corpus_ = cls.__from_response(corpus)
# add languages
languages = []
for lng in corpus["languages"]:
if "dialect" not in lng:
lng["dialect"] = ""
languages.append(Language(lng))
corpus_.kwargs["languages"] = languages
corpora.append(corpus_)
return {"results": corpora, "page_total": page_total, "page_number": page_number, "total": total}
else:
error_message = f"Corpus List Error: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def get_assets_from_page(
Expand Down Expand Up @@ -245,7 +266,7 @@ def create(
functions: List[Function] = [],
privacy: Privacy = Privacy.PRIVATE,
error_handler: ErrorHandler = ErrorHandler.SKIP,
api_key: Optional[Text] = None
api_key: Optional[Text] = None,
) -> Dict:
"""Asynchronous call to Upload a corpus to the user's dashboard.

Expand Down
72 changes: 45 additions & 27 deletions aixplain/factories/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Dataset Factory Class
"""

import aixplain.utils.config as config
import aixplain.processes.data_onboarding.onboard_functions as onboard_functions
import json
import os
Expand Down Expand Up @@ -49,7 +48,6 @@
from typing import Any, Dict, List, Optional, Text, Union
from urllib.parse import urljoin
from uuid import uuid4
from warnings import warn


class DatasetFactory(AssetFactory):
Expand Down Expand Up @@ -122,7 +120,7 @@ def __from_response(cls, response: Dict) -> Dataset:
target_data_list = [data[data_id] for data_id in out["dataIds"]]
data_name = target_data_list[0].name
target_data[data_name] = target_data_list
except:
except Exception:
pass

# process function
Expand Down Expand Up @@ -164,17 +162,27 @@ def get(cls, dataset_id: Text) -> Dataset:
Returns:
Dataset: Created 'Dataset' object
"""
url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview")
if cls.aixplain_key != "":
headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"}
try:
url = urljoin(cls.backend_url, f"sdk/datasets/{dataset_id}/overview")
if cls.aixplain_key != "":
headers = {"x-aixplain-key": f"{cls.aixplain_key}", "Content-Type": "application/json"}
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
logging.info(f"Start service for GET Dataset - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
resp = r.json()

except Exception as e:
error_message = f"Error retrieving Dataset {dataset_id}: {str(e)}"
logging.error(error_message, exc_info=True)
raise Exception(error_message)
if 200 <= r.status_code < 300:
logging.info(f"Dataset {dataset_id} retrieved successfully.")
return cls.__from_response(resp)
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
logging.info(f"Start service for GET Dataset - {url} - {headers}")
r = _request_with_retry("get", url, headers=headers)
resp = r.json()
if "statusCode" in resp and resp["statusCode"] == 404:
raise Exception(f"Dataset GET Error: Dataset {dataset_id} not found.")
return cls.__from_response(resp)
error_message = f"Dataset GET Error: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def list(
Expand Down Expand Up @@ -211,7 +219,7 @@ def list(
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}

assert 0 < page_size <= 100, f"Dataset List Error: Page size must be greater than 0 and not exceed 100."
assert 0 < page_size <= 100, "Dataset List Error: Page size must be greater than 0 and not exceed 100."
payload = {
"pageSize": page_size,
"pageNumber": page_number,
Expand Down Expand Up @@ -245,19 +253,29 @@ def list(
target_languages = [target_languages]
payload["output"]["languages"] = [lng.value["language"] for lng in target_languages]

logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
resp = r.json()
try:
logging.info(f"Start service for POST List Dataset - {url} - {headers} - {json.dumps(payload)}")
r = _request_with_retry("post", url, headers=headers, json=payload)
resp = r.json()

datasets, page_total, total = [], 0, 0
if "results" in resp:
results = resp["results"]
page_total = resp["pageTotal"]
total = resp["total"]
logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}")
for dataset in results:
datasets.append(cls.__from_response(dataset))
return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total}
except Exception as e:
error_message = f"Error listing datasets: {str(e)}"
logging.error(error_message, exc_info=True)
raise Exception(error_message)
if 200 <= r.status_code < 300:
datasets, page_total, total = [], 0, 0
if "results" in resp:
results = resp["results"]
page_total = resp["pageTotal"]
total = resp["total"]
logging.info(f"Response for POST List Dataset - Page Total: {page_total} / Total: {total}")
for dataset in results:
datasets.append(cls.__from_response(dataset))
return {"results": datasets, "page_total": page_total, "page_number": page_number, "total": total}
else:
error_message = f"Dataset List Error: Status {r.status_code} - {resp}"
logging.error(error_message)
raise Exception(error_message)

@classmethod
def create(
Expand All @@ -282,7 +300,7 @@ def create(
error_handler: ErrorHandler = ErrorHandler.SKIP,
s3_link: Optional[Text] = None,
aws_credentials: Optional[Dict[Text, Text]] = {"AWS_ACCESS_KEY_ID": None, "AWS_SECRET_ACCESS_KEY": None},
api_key: Optional[Text] = None
api_key: Optional[Text] = None,
) -> Dict:
"""Dataset Onboard

Expand Down
Loading