diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index ca3e1eec..ed3409a3 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -118,6 +118,7 @@ def create_model_from_response(response: Dict) -> Model: supports_streaming=response.get("supportsStreaming", False), status=status, function_type=function_type, + attributes=attributes, **additional_kwargs, ) diff --git a/aixplain/factories/tool_factory.py b/aixplain/factories/tool_factory.py index 9db2a9ea..8f746422 100644 --- a/aixplain/factories/tool_factory.py +++ b/aixplain/factories/tool_factory.py @@ -4,12 +4,12 @@ from aixplain.factories.model_factory.mixins import ModelGetterMixin, ModelListMixin from aixplain.modules.model import Model from aixplain.modules.model.index_model import IndexModel -from aixplain.modules.model.integration import Integration +from aixplain.modules.model.integration import Integration, AuthenticationSchema from aixplain.modules.model.integration import BaseAuthenticationParams from aixplain.factories.index_factory.utils import BaseIndexParams, AirParams, VectaraParams, ZeroEntropyParams, GraphRAGParams from aixplain.enums.index_stores import IndexStores from aixplain.modules.model.utility_model import BaseUtilityModelParams -from typing import Optional, Text, Union +from typing import Optional, Text, Union, Dict from aixplain.enums import ResponseStatus from aixplain.utils import config @@ -22,6 +22,8 @@ def create( cls, integration: Optional[Union[Text, Model]] = None, params: Optional[Union[BaseUtilityModelParams, BaseIndexParams, BaseAuthenticationParams]] = None, + authentication_schema: Optional[AuthenticationSchema] = None, + data: Optional[Dict] = None, **kwargs, ) -> Model: """Factory method to create indexes, script models and connections @@ -143,7 +145,17 @@ def add(a: int, b: int) -> int: assert ( connector.function_type == FunctionType.INTEGRATION ), f"The model you are trying to connect ({connector.id}) to is not a connector." - response = connector.connect(params) + + assert authentication_schema is not None, "Please provide the authentication schema to use (authentication_schema parameter)" + assert isinstance(authentication_schema, AuthenticationSchema), "authentication_schema must be an instance of AuthenticationSchema" + + auth_data = data if data is not None else {} + if not auth_data: + for key, value in kwargs.items(): + if key not in ['name', 'connector_id']: + auth_data[key] = value + + response = connector.connect(authentication_schema, params, auth_data) assert response.status == ResponseStatus.SUCCESS, f"Failed to connect to {connector.id} - {response.error_message}" connection = cls.get(response.data["id"]) if "redirectURL" in response.data: diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 2783ffc3..69c9f002 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -732,13 +732,13 @@ class Response(BaseModel): def test_agent_with_action_tool(): from aixplain.modules.model.integration import AuthenticationSchema - connector = ModelFactory.get("67eff5c0e05614297caeef98") + connector = ModelFactory.get("686432941223092cb4294d3f") # connect - response = connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + response = connector.connect(authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": os.getenv("SLACK_TOKEN")}) connection_id = response.data["id"] connection = ModelFactory.get(connection_id) - connection.action_scope = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + connection.action_scope = [action for action in connection.actions if action.code == "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL"] agent = AgentFactory.create( name="Test Agent", @@ -757,5 +757,5 @@ def test_agent_with_action_tool(): assert response is not None assert response["status"].lower() == "success" assert "helsinki" in response.data.output.lower() - assert "SLACK_CHAT_POST_MESSAGE" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] + assert "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] connection.delete() diff --git a/tests/functional/team_agent/team_agent_functional_test.py b/tests/functional/team_agent/team_agent_functional_test.py index f3852f60..5659bfa8 100644 --- a/tests/functional/team_agent/team_agent_functional_test.py +++ b/tests/functional/team_agent/team_agent_functional_test.py @@ -445,13 +445,13 @@ class Response(BaseModel): def test_team_agent_with_slack_connector(): from aixplain.modules.model.integration import AuthenticationSchema - connector = ModelFactory.get("67eff5c0e05614297caeef98") + connector = ModelFactory.get("686432941223092cb4294d3f") # connect - response = connector.connect(authentication_schema=AuthenticationSchema.BEARER, token=os.getenv("SLACK_TOKEN")) + response = connector.connect(authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": os.getenv("SLACK_TOKEN")}) connection_id = response.data["id"] connection = ModelFactory.get(connection_id) - connection.action_scope = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + connection.action_scope = [action for action in connection.actions if action.code == "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL"] agent = AgentFactory.create( name="Test Agent",