diff --git a/aixplain/factories/model_factory.py b/aixplain/factories/model_factory.py index 5df7c924..fc64f1a7 100644 --- a/aixplain/factories/model_factory.py +++ b/aixplain/factories/model_factory.py @@ -221,8 +221,8 @@ def _get_assets_from_page( @classmethod def list( cls, + function: Function, query: Optional[Text] = "", - function: Optional[Function] = None, suppliers: Optional[Union[Supplier, List[Supplier]]] = None, source_languages: Optional[Union[Language, List[Language]]] = None, target_languages: Optional[Union[Language, List[Language]]] = None, @@ -236,7 +236,7 @@ def list( """Gets the first k given models based on the provided task and language filters Args: - function (Optional[Function], optional): function filter. Defaults to None. + function (Function): function filter. source_languages (Optional[Union[Language, List[Language]]], optional): language filter of input data. Defaults to None. target_languages (Optional[Union[Language, List[Language]]], optional): language filter of output data. Defaults to None. is_finetunable (Optional[bool], optional): can be finetuned or not. Defaults to None. diff --git a/aixplain/utils/config.py b/aixplain/utils/config.py index 59805c60..03bbdccf 100644 --- a/aixplain/utils/config.py +++ b/aixplain/utils/config.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) BACKEND_URL = os.getenv("BACKEND_URL", "https://platform-api.aixplain.com") -MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com") +MODELS_RUN_URL = os.getenv("MODELS_RUN_URL", "https://models.aixplain.com/api/v1/execute") # GET THE API KEY FROM CMD TEAM_API_KEY = os.getenv("TEAM_API_KEY", "") AIXPLAIN_API_KEY = os.getenv("AIXPLAIN_API_KEY", "") diff --git a/tests/functional/general_assets/asset_functional_test.py b/tests/functional/general_assets/asset_functional_test.py index b0d8f6ef..266b04ea 100644 --- a/tests/functional/general_assets/asset_functional_test.py +++ b/tests/functional/general_assets/asset_functional_test.py @@ -33,7 +33,10 @@ def __get_asset_factory(asset_name): @pytest.mark.parametrize("asset_name", ["model", "dataset", "metric"]) def test_list(asset_name): AssetFactory = __get_asset_factory(asset_name) - asset_list = AssetFactory.list() + if asset_name == "model": + asset_list = AssetFactory.list(function=Function.TRANSLATION) + else: + asset_list = AssetFactory.list() assert asset_list["page_total"] == len(asset_list["results"]) @@ -62,7 +65,7 @@ def test_model_function(): def test_model_supplier(): desired_suppliers = [Supplier.GOOGLE] - models = ModelFactory.list(suppliers=desired_suppliers)["results"] + models = ModelFactory.list(suppliers=desired_suppliers, function=Function.TRANSLATION)["results"] for model in models: assert model.supplier.value in [desired_supplier.value for desired_supplier in desired_suppliers] @@ -89,14 +92,14 @@ def test_model_sort(): def test_model_ownership(): - models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED)["results"] + models = ModelFactory.list(ownership=OwnershipType.SUBSCRIBED, function=Function.TRANSLATION)["results"] for model in models: assert model.is_subscribed is True def test_model_query(): query = "Mongo" - models = ModelFactory.list(query=query)["results"] + models = ModelFactory.list(query=query, function=Function.TRANSLATION)["results"] for model in models: assert query in model.name