Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aixplain/factories/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def _get_assets_from_page(
@classmethod
def list(
cls,
function: Function,
query: Optional[Text] = "",
function: Optional[Function] = None,
suppliers: Optional[Union[Supplier, List[Supplier]]] = None,
source_languages: Optional[Union[Language, List[Language]]] = None,
target_languages: Optional[Union[Language, List[Language]]] = None,
Expand All @@ -236,7 +236,7 @@ def list(
"""Gets the first k given models based on the provided task and language filters

Args:
function (Optional[Function], optional): function filter. Defaults to None.
function (Function): function filter.
source_languages (Optional[Union[Language, List[Language]]], optional): language filter of input data. Defaults to None.
target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None.
is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None.
Expand Down
2 changes: 1 addition & 1 deletion aixplain/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
logger = logging.getLogger(__name__)

BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com")
MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com")
MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute")
# GET THE API KEY FROM CMD
TEAM_API_KEY = os.getenv("TEAM_API_KEY", "")
AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "")
Expand Down
11 changes: 7 additions & 4 deletions tests/functional/general_assets/asset_functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ def __get_asset_factory(asset_name):
@pytest.mark.parametrize("asset_name", ["model", "dataset", "metric"])
def test_list(asset_name):
AssetFactory = __get_asset_factory(asset_name)
asset_list = AssetFactory.list()
if asset_name == "model":
asset_list = AssetFactory.list(function=Function.TRANSLATION)
else:
asset_list = AssetFactory.list()
assert asset_list["page_total"] == len(asset_list["results"])


Expand Down Expand Up @@ -62,7 +65,7 @@ def test_model_function():

def test_model_supplier():
desired_suppliers = [Supplier.GOOGLE]
models = ModelFactory.list(suppliers=desired_suppliers)["results"]
models = ModelFactory.list(suppliers=desired_suppliers, function=Function.TRANSLATION)["results"]
for model in models:
assert model.supplier.value in [desired_supplier.value for desired_supplier in desired_suppliers]

Expand All @@ -89,14 +92,14 @@ def test_model_sort():


def test_model_ownership():
models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED)["results"]
models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED, function=Function.TRANSLATION)["results"]
for model in models:
assert model.is_subscribed is True


def test_model_query():
query = "Mongo"
models = ModelFactory.list(query=query)["results"]
models = ModelFactory.list(query=query, function=Function.TRANSLATION)["results"]
for model in models:
assert query in model.name

Expand Down