5 changes: 5 additions & 0 deletions aixplain/factories/model_factory.py
@@ -31,6 +31,7 @@
from urllib.parse import urljoin
from warnings import warn
from aixplain.enums.function import FunctionInputOutput
from datetime import datetime


class ModelFactory:
@@ -67,6 +68,9 @@ def _create_model_from_response(cls, response: Dict) -> Model:
if function == Function.TEXT_GENERATION:
ModelClass = LLM

created_at = None
if "createdAt" in response and response["createdAt"]:
created_at = datetime.fromisoformat(response["createdAt"].replace("Z", "+00:00"))
function_id = response["function"]["id"]
function = Function(function_id)
function_io = FunctionInputOutput.get(function_id, None)
@@ -80,6 +84,7 @@ def _create_model_from_response(cls, response: Dict) -> Model:
api_key=response["api_key"],
cost=response["pricing"],
function=function,
created_at=created_at,
parameters=parameters,
input_params=input_params,
output_params=output_params,
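
Not part of the diff: a minimal standalone sketch of the createdAt parsing added above, assuming the backend returns an ISO-8601 timestamp with a trailing "Z". The replace("Z", "+00:00") step matters because datetime.fromisoformat() rejects the "Z" suffix on Python versions before 3.11.

from datetime import datetime, timezone

raw = "2023-08-30T10:53:10.827Z"  # hypothetical backend value
# fromisoformat() cannot parse a trailing "Z" before Python 3.11, hence the replace()
created_at = datetime.fromisoformat(raw.replace("Z", "+00:00"))
assert created_at.tzinfo == timezone.utc
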
20 changes: 16 additions & 4 deletions aixplain/modules/model/__init__.py
@@ -30,6 +30,7 @@
from urllib.parse import urljoin
from aixplain.utils.file_utils import _request_with_retry
from typing import Union, Optional, Text, Dict
from datetime import datetime


class Model(Asset):
@@ -63,6 +64,7 @@ def __init__(
function: Optional[Function] = None,
is_subscribed: bool = False,
cost: Optional[Dict] = None,
created_at: Optional[datetime] = None,
input_params: Optional[Dict] = None,
output_params: Optional[Dict] = None,
**additional_info,
@@ -88,8 +90,9 @@ def __init__(
self.backend_url = config.BACKEND_URL
self.function = function
self.is_subscribed = is_subscribed
self.input_params = input_params
self.output_params = output_params
self.created_at = created_at
self.input_params = input_params
self.output_params = output_params

def to_dict(self) -> Dict:
"""Get the model info as a Dictionary
Contributor:
Please check whether this to_dict method works with a datetime attribute.

Collaborator (Author):
It seems to work, this is what it returns:
{'id': '64e615671567f848804985e1', 'name': 'GPT2', 'supplier': <Supplier.OPENAI: {'id': 1777, 'name': 'OpenAI', 'code': 'openai'}>, 'additional_info': {'parameters': {'language': ['en']}}, 'createdAt': datetime.datetime(2023, 8, 30, 10, 53, 10, 827000, tzinfo=datetime.timezone.utc)}
I added createdAt to to_dict

@@ -98,7 +101,14 @@ def to_dict(self) -> Dict:
Dict: Model Information
"""
clean_additional_info = {k: v for k, v in self.additional_info.items() if v is not None}
return {"id": self.id, "name": self.name, "supplier": self.supplier, "additional_info": clean_additional_info, "input_params": self.input_params,"output_params": self.output_params,}
return {
"id": self.id,
"name": self.name,
"supplier": self.supplier,
"additional_info": clean_additional_info,
"input_params": self.input_params,
"output_params": self.output_params,
}

def __repr__(self):
try:
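
Not part of the diff: a short sketch of the point raised in the review thread above. A datetime passes through to_dict() unchanged as a Python object, but it is not directly JSON-serializable, so any caller that dumps the dict to JSON would need to convert it first (isoformat() is one option). This is only an illustration, not necessarily how the SDK serializes it.

import json
from datetime import datetime, timezone

info = {
    "id": "64e615671567f848804985e1",
    "createdAt": datetime(2023, 8, 30, 10, 53, 10, 827000, tzinfo=timezone.utc),
}
# json.dumps(info) raises TypeError: Object of type datetime is not JSON serializable
serializable = {k: v.isoformat() if isinstance(v, datetime) else v for k, v in info.items()}
print(json.dumps(serializable))  # {"id": "...", "createdAt": "2023-08-30T10:53:10.827000+00:00"}
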
@@ -263,7 +273,9 @@ def run_async(self, data: Union[Text, Dict], name: Text = "model_process", param
error = "Validation-related error: Please ensure all required fields are provided and correctly formatted."
else:
status_code = str(r.status_code)
error = f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request."
error = (
f"Status {status_code}: Unspecified error: An unspecified error occurred while processing your request."
)
response = {"status": "FAILED", "error_message": error}
logging.error(f"Error in request for {name} - {r.status_code}: {error}")
except Exception:
42 changes: 29 additions & 13 deletions tests/functional/finetune/finetune_functional_test.py
@@ -1,5 +1,4 @@
__author__ = "lucaspavanelli"

"""
Copyright 2022 The aiXplain SDK authors

@@ -26,6 +25,7 @@
from aixplain.factories import FinetuneFactory
from aixplain.modules.finetune.cost import FinetuneCost
from aixplain.enums import Function, Language
from datetime import datetime, timedelta, timezone

import pytest

@@ -40,11 +40,6 @@ def read_data(data_path):
return json.load(open(data_path, "r"))


@pytest.fixture(scope="module", params=read_data(RUN_FILE))
def run_input_map(request):
return request.param


@pytest.fixture(scope="module", params=read_data(ESTIMATE_COST_FILE))
def estimate_cost_input_map(request):
return request.param
@@ -60,11 +55,32 @@ def validate_prompt_input_map(request):
return request.param


def test_end2end(run_input_map):
model = ModelFactory.get(run_input_map["model_id"])
dataset_list = [DatasetFactory.list(query=run_input_map["dataset_name"])["results"][0]]
def pytest_generate_tests(metafunc):
if "input_map" in metafunc.fixturenames:
four_weeks_ago = datetime.now(timezone.utc) - timedelta(weeks=4)
models = ModelFactory.list(function=Function.TEXT_GENERATION, is_finetunable=True)["results"]

recent_models = [
{
"model_name": model.name,
"model_id": model.id,
"dataset_name": "Test text generation dataset",
"inference_data": "Hello!",
"required_dev": True,
"search_metadata": False,
}
for model in models
if model.created_at is not None and model.created_at >= four_weeks_ago
]
recent_models += read_data(RUN_FILE)
metafunc.parametrize("input_map", recent_models)


def test_end2end(input_map):
model = input_map["model_id"]
dataset_list = [DatasetFactory.list(query=input_map["dataset_name"])["results"][0]]
train_percentage, dev_percentage = 100, 0
if run_input_map["required_dev"]:
if input_map["required_dev"]:
train_percentage, dev_percentage = 80, 20
finetune = FinetuneFactory.create(
str(uuid.uuid4()), dataset_list, model, train_percentage=train_percentage, dev_percentage=dev_percentage
Expand All @@ -85,12 +101,12 @@ def test_end2end(run_input_map):
assert finetune_model.check_finetune_status().model_status.value == "onboarded"
time.sleep(30)
print(f"Model dict: {finetune_model.__dict__}")
result = finetune_model.run(run_input_map["inference_data"])
result = finetune_model.run(input_map["inference_data"])
print(f"Result: {result}")
assert result is not None
if run_input_map["search_metadata"]:
if input_map["search_metadata"]:
assert "details" in result
assert len(result["details"]) > 0
assert len(result["details"]) > 0
assert "metadata" in result["details"][0]
assert len(result["details"][0]["metadata"]) > 0
finetune_model.delete()
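
Not part of the diff: a minimal, self-contained illustration of the pytest_generate_tests hook used above. The hook runs at collection time and parametrizes every test that declares the named fixture, which is why the static run_input_map fixture could be removed. The created_at filter also relies on created_at being a timezone-aware datetime; comparing a naive datetime against datetime.now(timezone.utc) would raise a TypeError.

# conftest.py or the test module itself (hypothetical minimal example)
def pytest_generate_tests(metafunc):
    if "input_map" in metafunc.fixturenames:
        cases = [{"model_id": "dummy-1"}, {"model_id": "dummy-2"}]  # stand-in test cases
        metafunc.parametrize("input_map", cases)

def test_uses_input_map(input_map):
    assert "model_id" in input_map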