Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aixplain/enums/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@
from .privacy import Privacy
from .storage_type import StorageType
from .supplier import Supplier
from .sort_by import SortBy
from .sort_order import SortOrder
30 changes: 30 additions & 0 deletions aixplain/enums/sort_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
__author__ = "aiXplain"

"""
Copyright 2023 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: aiXplain team
Date: March 20th 2023
Description:
Sort By Enum
"""

from enum import Enum


class SortBy(Enum):
CREATION_DATE = "createdAt"
PRICE = "normalizedPrice"
POPULARITY = "totalSubscribed"
29 changes: 29 additions & 0 deletions aixplain/enums/sort_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
__author__ = "aiXplain"

"""
Copyright 2023 The aiXplain SDK authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Author: aiXplain team
Date: March 20th 2023
Description:
Sort By Enum
"""

from enum import Enum


class SortOrder(Enum):
ASCENDING = 1
DESCENDING = -1
12 changes: 11 additions & 1 deletion aixplain/factories/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import json
import logging
from aixplain.modules.model import Model
from aixplain.enums import Function, Language, OwnershipType, Supplier
from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder
from aixplain.utils import config
from aixplain.utils.file_utils import _request_with_retry
from urllib.parse import urljoin
Expand Down Expand Up @@ -130,6 +130,8 @@ def _get_assets_from_page(
target_languages: Union[Language, List[Language]],
is_finetunable: bool = None,
ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None,
sort_by: Optional[SortBy] = None,
sort_order: SortOrder = SortOrder.ASCENDING,
) -> List[Model]:
try:
url = urljoin(cls.backend_url, f"sdk/models/paginate")
Expand All @@ -146,6 +148,7 @@ def _get_assets_from_page(
if isinstance(ownership, OwnershipType) is True:
ownership = [ownership]
filter_params["ownership"] = [ownership_.value for ownership_ in ownership]

lang_filter_params = []
if source_languages is not None:
if isinstance(source_languages, Language):
Expand All @@ -162,6 +165,8 @@ def _get_assets_from_page(
if function == Function.TRANSLATION:
code = "targetlanguage"
lang_filter_params.append({"code": code, "value": target_languages[0].value["language"]})
if sort_by is not None:
filter_params["sort"] = [{"dir": sort_order.value, "field": sort_by.value}]
if len(lang_filter_params) != 0:
filter_params["ioFilter"] = lang_filter_params
if cls.aixplain_key != "":
Expand Down Expand Up @@ -191,6 +196,8 @@ def list(
target_languages: Optional[Union[Language, List[Language]]] = None,
is_finetunable: Optional[bool] = None,
ownership: Optional[Tuple[OwnershipType, List[OwnershipType]]] = None,
sort_by: Optional[SortBy] = None,
sort_order: SortOrder = SortOrder.ASCENDING,
page_number: int = 0,
page_size: int = 20,
) -> List[Model]:
Expand All @@ -202,6 +209,7 @@ def list(
target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None.
is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None.
ownership (Optional[Tuple[OwnershipType, List[OwnershipType]]], optional): Ownership filters (e.g. SUBSCRIBED, OWNER). Defaults to None.
sort_by (Optional[SortBy], optional): sort the retrived models by a specific attribute,
page_number (int, optional): page number. Defaults to 0.
page_size (int, optional): page size. Defaults to 20.

Expand All @@ -219,6 +227,8 @@ def list(
target_languages,
is_finetunable,
ownership,
sort_by,
sort_order,
)
return {
"results": models,
Expand Down
23 changes: 22 additions & 1 deletion tests/functional/general_assets/asset_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
load_dotenv()
from aixplain.factories import ModelFactory, DatasetFactory, MetricFactory, PipelineFactory
from pathlib import Path
from aixplain.enums import Function, OwnershipType, Supplier
from aixplain.enums import Function, Language, OwnershipType, Supplier, SortBy, SortOrder

import pytest

Expand Down Expand Up @@ -63,6 +63,27 @@ def test_model_supplier():
assert model.supplier.value in [desired_supplier.value for desired_supplier in desired_suppliers]


def test_model_sort():
function = Function.TRANSLATION
src_language = Language.Portuguese
trg_language = Language.English

models = ModelFactory.list(
function=function,
source_languages=src_language,
target_languages=trg_language,
sort_by=SortBy.PRICE,
sort_order=SortOrder.DESCENDING,
)["results"]
for idx in range(1, len(models)):
prev_model = models[idx - 1]
model = models[idx]

prev_model_price = prev_model.additional_info["pricing"]["price"]
model_price = model.additional_info["pricing"]["price"]
assert prev_model_price >= model_price


def test_model_ownership():
models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED)["results"]
for model in models:
Expand Down