Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
23b2898
Update Finetuner search metadata functional tests (#172)
lucas-aixplain May 2, 2024
208a081
Downgrade dataclasses-json for compatibility (#170)
thiago-aixplain May 2, 2024
a837e1a
Fix model cost parameters (#179)
thiago-aixplain May 10, 2024
754f478
Treat label URLs (#176)
thiago-aixplain May 15, 2024
f1c9935
Add new metric test (#181)
thiago-aixplain Jun 4, 2024
a48ccfd
LLMModel class and parameters (#184)
thiago-aixplain Jun 5, 2024
c7f59ce
Gpus (#185)
mikelam-us-aixplain Jun 5, 2024
16eb2e1
Create and get Pipelines with api key as input parameter (#187)
thiago-aixplain Jun 7, 2024
2849d6f
Merge branch 'test' into development
thiago-aixplain Jun 11, 2024
04246b1
M 6769474660 save pipelines (#191)
thiago-aixplain Jun 17, 2024
73021a7
M 6769474660 save pipelines (#192)
thiago-aixplain Jun 18, 2024
474602b
Solving bug when LLM parameters are set on data (#196)
thiago-aixplain Jun 26, 2024
c471703
Merge branch 'test' into development
thiago-aixplain Jun 26, 2024
3695686
Fix pipeline functional test (#200)
lucas-aixplain Jul 3, 2024
9014061
M 6656407247 agentification (#197)
thiago-aixplain Jul 13, 2024
e9091c2
Fixing circular import in the SDK (#211)
thiago-aixplain Jul 30, 2024
f437815
create model/pipeline tools from AgentFactory (#214)
thiago-aixplain Aug 2, 2024
8457087
Merge branch 'test' into development
thiago-aixplain Aug 6, 2024
03009c6
Set model ID as a parameter (#216)
thiago-aixplain Aug 7, 2024
02f7482
Content inputs to be processed according to the query. (#215)
thiago-aixplain Aug 7, 2024
4947959
ENG-1: programmatic api introduced (#219)
kadirpekel Aug 9, 2024
ef16dd5
Updated image upload tests (#213)
mikelam-us-aixplain Aug 12, 2024
d0ad51d
Eng 217 local path (#220)
thiago-aixplain Aug 13, 2024
dca1a37
Eng 389 fix tests (#222)
thiago-aixplain Aug 13, 2024
d43f67f
Merge branch 'test' into development
thiago-aixplain Aug 13, 2024
b113368
Tool Validation when creating agents (#226)
xainaz Aug 19, 2024
0032947
Eng 398 sdk get users credits - Initial (#232)
xainaz Aug 20, 2024
a567535
Eng 398 sdk get users credits (#234)
thiago-aixplain Aug 20, 2024
e919fab
Removed wallet_factoy.py (#235)
xainaz Aug 21, 2024
9ffe3f7
Merge branch 'test' into development
thiago-aixplain Aug 22, 2024
115bf13
Adding supervisor/planning options into SDK (#233)
thiago-aixplain Aug 22, 2024
3357e56
Adjustments to get user credits (#237)
xainaz Aug 23, 2024
ee76afd
Put conditions inside try statements according to changes required. (…
xainaz Aug 23, 2024
1660f5f
Fixing none credit (#238)
xainaz Aug 27, 2024
ed20ba7
Merge branch 'test' into development
thiago-aixplain Aug 27, 2024
481dab2
Merge branch 'test' into development
thiago-aixplain Aug 27, 2024
9a89f52
Update click dependency (#241)
thiago-aixplain Aug 28, 2024
cb0d313
Added input and output attributes to model (#244)
xainaz Sep 2, 2024
716d898
Eng 467 ai xplain sdk update finetune functional tests to cover all n…
xainaz Sep 3, 2024
50d7c6a
Merge branch 'test' into development
thiago-aixplain Sep 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 39 additions & 17 deletions aixplain/factories/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from aixplain.utils.file_utils import _request_with_retry
from urllib.parse import urljoin
from warnings import warn
from aixplain.enums.function import FunctionInputOutput
from datetime import datetime


class ModelFactory:
Expand Down Expand Up @@ -66,14 +68,26 @@ def _create_model_from_response(cls, response: Dict) -> Model:
if function == Function.TEXT_GENERATION:
ModelClass = LLM

created_at = None
if "createdAt" in response and response["createdAt"]:
created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00"))
function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
input_params = {param["code"]: param for param in function_io["spec"]["params"]}
output_params = {param["code"]: param for param in function_io["spec"]["output"]}

return ModelClass(
response["id"],
response["name"],
supplier=response["supplier"],
api_key=response["api_key"],
cost=response["pricing"],
function=function,
created_at=created_at,
parameters=parameters,
input_params=input_params,
output_params=output_params,
is_subscribed=True if "subscription" in response else False,
version=response["version"]["id"],
)
Expand Down Expand Up @@ -270,7 +284,7 @@ def list_host_machines(cls, api_key: Optional[Text] = None) -> List[Dict]:
for dictionary in response_dicts:
del dictionary["id"]
return response_dicts

@classmethod
def list_gpus(cls, api_key: Optional[Text] = None) -> List[List[Text]]:
"""List GPU names on which you can host your language model.
Expand Down Expand Up @@ -335,7 +349,7 @@ def create_asset_repo(
input_modality: Text,
output_modality: Text,
documentation_url: Optional[Text] = "",
api_key: Optional[Text] = None
api_key: Optional[Text] = None,
) -> Dict:
"""Creates an image repository for this model and registers it in the
platform backend.
Expand All @@ -362,7 +376,7 @@ def create_asset_repo(
function_id = function_dict["id"]
if function_id is None:
raise Exception(f"Invalid function name {function}")
create_url = urljoin(config.BACKEND_URL, f"sdk/models/onboard")
create_url = urljoin(config.BACKEND_URL, "sdk/models/onboard")
logging.debug(f"URL: {create_url}")
if api_key:
headers = {"x-api-key": f"{api_key}", "Content-Type": "application/json"}
Expand All @@ -373,19 +387,14 @@ def create_asset_repo(
"model": {
"name": name,
"description": description,
"connectionType": [
"synchronous"
],
"connectionType": ["synchronous"],
"function": function_id,
"modalities": [
f"{input_modality}-{output_modality}"
],
"modalities": [f"{input_modality}-{output_modality}"],
"documentationUrl": documentation_url,
"sourceLanguage": source_language
"sourceLanguage": source_language,
},
"source": "aixplain-ecr",
"onboardingParams": {
}
"onboardingParams": {},
}
logging.debug(f"Body: {str(payload)}")
response = _request_with_retry("post", create_url, headers=headers, json=payload)
Expand All @@ -412,12 +421,18 @@ def asset_repo_login(cls, api_key: Optional[Text] = None) -> Dict:
else:
headers = {"Authorization": f"Token {config.TEAM_API_KEY}", "Content-Type": "application/json"}
response = _request_with_retry("post", login_url, headers=headers)
print(f"Response: {response}")
response_dict = json.loads(response.text)
return response_dict

@classmethod
def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_machine: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
def onboard_model(
cls,
model_id: Text,
image_tag: Text,
image_hash: Text,
host_machine: Optional[Text] = "",
api_key: Optional[Text] = None,
) -> Dict:
"""Onboard a model after its image has been pushed to ECR.

Args:
Expand Down Expand Up @@ -446,7 +461,14 @@ def onboard_model(cls, model_id: Text, image_tag: Text, image_hash: Text, host_m
return response

@classmethod
def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Optional[Text] = "", hf_token: Optional[Text] = "", api_key: Optional[Text] = None) -> Dict:
def deploy_huggingface_model(
cls,
name: Text,
hf_repo_id: Text,
revision: Optional[Text] = "",
hf_token: Optional[Text] = "",
api_key: Optional[Text] = None,
) -> Dict:
"""Onboards and deploys a Hugging Face large language model.

Args:
Expand Down Expand Up @@ -477,8 +499,8 @@ def deploy_huggingface_model(cls, name: Text, hf_repo_id: Text, revision: Option
"hf_supplier": supplier,
"hf_model_name": model_name,
"hf_token": hf_token,
"revision": revision
}
"revision": revision,
},
}
response = _request_with_retry("post", deploy_url, headers=headers, json=body)
logging.debug(response.text)
Expand Down
22 changes: 20 additions & 2 deletions aixplain/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from urllib.parse import urljoin
from aixplain.utils.file_utils import _request_with_retry
from typing import Union, Optional, Text, Dict
from datetime import datetime


class Model(Asset):
Expand All @@ -48,6 +49,8 @@ class Model(Asset):
backend_url (str): URL of the backend.
pricing (Dict, optional): model price. Defaults to None.
**additional_info: Any additional Model info to be saved
input_params (Dict, optional): input parameters for the function.
output_params (Dict, optional): output parameters for the function.
"""

def __init__(
Expand All @@ -61,6 +64,9 @@ def __init__(
function: Optional[Function] = None,
is_subscribed: bool = False,
cost: Optional[Dict] = None,
created_at: Optional[datetime] = None,
input_params: Optional[Dict] = None,
output_params: Optional[Dict] = None,
**additional_info,
) -> None:
"""Model Init
Expand All @@ -84,6 +90,9 @@ def __init__(
self.backend_url = config.BACKEND_URL
self.function = function
self.is_subscribed = is_subscribed
self.created_at = created_at
self.input_params = input_params
self.output_params = output_params

def to_dict(self) -> Dict:
"""Get the model info as a Dictionary
Expand All @@ -92,7 +101,14 @@ def to_dict(self) -> Dict:
Dict: Model Information
"""
clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None}
return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info}
return {
"id": self.id,
"name": self.name,
"supplier": self.supplier,
"additional_info": clean_additional_info,
"input_params": self.input_params,
"output_params": self.output_params,
}

def __repr__(self):
try:
Expand Down Expand Up @@ -257,7 +273,9 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
error = "Validation-related error: Please ensure all required fields are provided and correctly formatted."
else:
status_code = str(r.status_code)
error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request."
error = (
f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request."
)
response = {"status": "FAILED", "error_message": error}
logging.error(f"Error in request for {name} - {r.status_code}: {error}")
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespaces = true

[project]
name = "aiXplain"
version = "0.2.13rc2"
version = "0.2.18"
description = "aiXplain SDK adds AI functions to software."
readme = "README.md"
requires-python = ">=3.5, <4"
Expand Down Expand Up @@ -49,7 +49,7 @@ dependencies = [
"python-dotenv>=1.0.0",
"validators>=0.20.0",
"filetype>=1.2.0",
"click>=7.1.2,<8.0.0",
"click>=7.1.2",
"PyYAML>=6.0.1",
"dataclasses-json>=0.5.2",
"Jinja2==3.1.4",
Expand Down
42 changes: 29 additions & 13 deletions tests/functional/finetune/finetune_functional_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
__author__ = "lucaspavanelli"

"""
Copyright 2022 The aiXplain SDK authors

Expand All @@ -26,6 +25,7 @@
from aixplain.factories import FinetuneFactory
from aixplain.modules.finetune.cost import FinetuneCost
from aixplain.enums import Function, Language
from datetime import datetime, timedelta, timezone

import pytest

Expand All @@ -40,11 +40,6 @@ def read_data(data_path):
return json.load(open(data_path, "r"))


@pytest.fixture(scope="module", params=read_data(RUN_FILE))
def run_input_map(request):
return request.param


@pytest.fixture(scope="module", params=read_data(ESTIMATE_COST_FILE))
def estimate_cost_input_map(request):
return request.param
Expand All @@ -60,11 +55,32 @@ def validate_prompt_input_map(request):
return request.param


def test_end2end(run_input_map):
model = ModelFactory.get(run_input_map["model_id"])
dataset_list = [DatasetFactory.list(query=run_input_map["dataset_name"])["results"][0]]
def pytest_generate_tests(metafunc):
if "input_map" in metafunc.fixturenames:
four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4)
models = ModelFactory.list(function=Function.TEXT_GENERATION, is_finetunable=True)["results"]

recent_models = [
{
"model_name": model.name,
"model_id": model.id,
"dataset_name": "Test text generation dataset",
"inference_data": "Hello!",
"required_dev": True,
"search_metadata": False,
}
for model in models
if model.created_at is not None and model.created_at >= four_weeks_ago
]
recent_models += read_data(RUN_FILE)
metafunc.parametrize("input_map", recent_models)


def test_end2end(input_map):
model = input_map["model_id"]
dataset_list = [DatasetFactory.list(query=input_map["dataset_name"])["results"][0]]
train_percentage, dev_percentage = 100, 0
if run_input_map["required_dev"]:
if input_map["required_dev"]:
train_percentage, dev_percentage = 80, 20
finetune = FinetuneFactory.create(
str(uuid.uuid4()), dataset_list, model, train_percentage=train_percentage, dev_percentage=dev_percentage
Expand All @@ -85,12 +101,12 @@ def test_end2end(run_input_map):
assert finetune_model.check_finetune_status().model_status.value == "onboarded"
time.sleep(30)
print(f"Model dict: {finetune_model.__dict__}")
result = finetune_model.run(run_input_map["inference_data"])
result = finetune_model.run(input_map["inference_data"])
print(f"Result: {result}")
assert result is not None
if run_input_map["search_metadata"]:
if input_map["search_metadata"]:
assert "details" in result
assert len(result["details"]) > 0
assert len(result["details"]) > 0
assert "metadata" in result["details"][0]
assert len(result["details"][0]["metadata"]) > 0
finetune_model.delete()
Expand Down
22 changes: 22 additions & 0 deletions tests/functional/general_assets/asset_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,25 @@ def test_llm_instantiation():
"""Test that the LLM model is correctly instantiated."""
models = ModelFactory.list(function=Function.TEXT_GENERATION)["results"]
assert isinstance(models[0], LLM)


def test_model_io():
model_id = "64aee5824d34b1221e70ac07"
model = ModelFactory.get(model_id)

expected_input = {
"text": {
"name": "Text Prompt",
"code": "text",
"required": True,
"isFixed": False,
"dataType": "text",
"dataSubType": "text",
"multipleValues": False,
"defaultValues": [],
}
}
expected_output = {"data": {"name": "Generated Image", "code": "data", "defaultValue": [], "dataType": "image"}}

assert model.input_params == expected_input
assert model.output_params == expected_output
14 changes: 6 additions & 8 deletions tests/unit/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,22 @@ def test_failed_poll():
@pytest.mark.parametrize(
"status_code,error_message",
[
(401,"Unauthorized API key: Please verify the spelling of the API key and its current validity."),
(465,"Subscription-related error: Please ensure that your subscription is active and has not expired."),
(475,"Billing-related error: Please ensure you have enough credits to run this model. "),
(401, "Unauthorized API key: Please verify the spelling of the API key and its current validity."),
(465, "Subscription-related error: Please ensure that your subscription is active and has not expired."),
(475, "Billing-related error: Please ensure you have enough credits to run this model. "),
(485, "Supplier-related error: Please ensure that the selected supplier provides the model you are trying to access."),
(495, "Validation-related error: Please ensure all required fields are provided and correctly formatted."),
(501, "Status 501: Unspecified error: An unspecified error occurred while processing your request."),

],
)

def test_run_async_errors(status_code, error_message):
base_url = config.MODELS_RUN_URL
model_id = "model-id"
execute_url = urljoin(base_url, f"execute/{model_id}")

with requests_mock.Mocker() as mock:
mock.post(execute_url, status_code=status_code)
test_model = Model(id=model_id, name="Test Model",url=base_url)
test_model = Model(id=model_id, name="Test Model", url=base_url)
response = test_model.run_async(data="input_data")
assert response["status"] == "FAILED"
assert response["error_message"] == error_message
assert response["error_message"] == error_message