diff --git a/aixplain/enums/function_type.py b/aixplain/enums/function_type.py index b9655813..465e0e31 100644 --- a/aixplain/enums/function_type.py +++ b/aixplain/enums/function_type.py @@ -32,6 +32,5 @@ class FunctionType(Enum): METRIC = "metric" SEARCH = "search" INTEGRATION = "connector" - CONNECTION = "connection" - MCPSERVER = 'mcpserver' - MCPCONNECTION = "mcpconnection" + MCP_CONNECTION = "mcpconnection" + MCPSERVER = "mcpserver" diff --git a/aixplain/factories/model_factory/utils.py b/aixplain/factories/model_factory/utils.py index 5021ae18..9f32e51f 100644 --- a/aixplain/factories/model_factory/utils.py +++ b/aixplain/factories/model_factory/utils.py @@ -5,6 +5,7 @@ from aixplain.modules.model.index_model import IndexModel from aixplain.modules.model.integration import Integration from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model.mcp_connection import MCPConnection from aixplain.modules.model.utility_model import UtilityModel from aixplain.modules.model.utility_model import UtilityModelInput from aixplain.enums import DataType, Function, FunctionType, Language, OwnershipType, Supplier, SortBy, SortOrder, AssetStatus @@ -70,8 +71,10 @@ def create_model_from_response(response: Dict) -> Model: ModelClass = IndexModel elif function_type == FunctionType.INTEGRATION: ModelClass = Integration - elif function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION: + elif function_type == FunctionType.CONNECTION : ModelClass = ConnectionTool + elif function_type == FunctionType.MCP_CONNECTION: + ModelClass = MCPConnection elif function == Function.UTILITIES: ModelClass = UtilityModel inputs = [ diff --git a/aixplain/modules/model/connection.py b/aixplain/modules/model/connection.py index 430289ea..d77fc565 100644 --- a/aixplain/modules/model/connection.py +++ b/aixplain/modules/model/connection.py @@ -53,7 +53,10 @@ def __init__( scope (Text, optional): action scope of the connection. Defaults to None. **additional_info: Any additional Model info to be saved """ - assert function_type == FunctionType.CONNECTION or function_type == FunctionType.MCPCONNECTION, "Connection only supports connection function" + assert ( + function_type == FunctionType.CONNECTION or function_type == FunctionType.MCP_CONNECTION + ), "Connection only supports connection function" + super().__init__( id=id, name=name, diff --git a/aixplain/modules/model/integration.py b/aixplain/modules/model/integration.py index 69e831d5..081c8cbc 100644 --- a/aixplain/modules/model/integration.py +++ b/aixplain/modules/model/integration.py @@ -94,6 +94,9 @@ def connect(self, authentication_schema: AuthenticationSchema, args: Optional[Ba id: Connection ID (retrieve it with ModelFactory.get(id)) redirectUrl: Redirect URL to complete the connection (only for OAuth2) """ + if self.id == "68549a33ba00e44f357896f1": + return self.run({"data": kwargs.get("data")}) + if args is None: args = build_connector_params(**kwargs) @@ -131,6 +134,9 @@ def connect(self, authentication_schema: AuthenticationSchema, args: Optional[Ba f"Before using the tool, please visit the following URL to complete the connection: {response.data['redirectURL']}" ) return response + else: + raise ValueError(f"Invalid authentication schema: {authentication_schema}") + def __repr__(self): diff --git a/aixplain/modules/model/mcp_connection.py b/aixplain/modules/model/mcp_connection.py new file mode 100644 index 00000000..19155694 --- /dev/null +++ b/aixplain/modules/model/mcp_connection.py @@ -0,0 +1,102 @@ +from aixplain.enums import Function, Supplier, FunctionType, ResponseStatus +from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model import Model +from typing import Text, Optional, Union, Dict, List + + +class ConnectAction: + name: Text + description: Text + code: Optional[Text] = None + inputs: Optional[Dict] = None + + def __init__(self, name: Text, description: Text, code: Optional[Text] = None, inputs: Optional[Dict] = None): + self.name = name + self.description = description + self.code = code + self.inputs = inputs + + def __repr__(self): + return f"Action(code={self.code}, name={self.name})" + + +class MCPConnection(ConnectionTool): + actions: List[ConnectAction] + action_scope: Optional[List[ConnectAction]] = None + + def __init__( + self, + id: Text, + name: Text, + description: Text = "", + api_key: Optional[Text] = None, + supplier: Union[Dict, Text, Supplier, int] = "aiXplain", + version: Optional[Text] = None, + function: Optional[Function] = None, + is_subscribed: bool = False, + cost: Optional[Dict] = None, + function_type: Optional[FunctionType] = FunctionType.CONNECTION, + **additional_info, + ) -> None: + """Connection Init + + Args: + id (Text): ID of the Model + name (Text): Name of the Model + description (Text, optional): description of the model. Defaults to "". + api_key (Text, optional): API key of the Model. Defaults to None. + supplier (Union[Dict, Text, Supplier, int], optional): supplier of the asset. Defaults to "aiXplain". + version (Text, optional): version of the model. Defaults to "1.0". + function (Function, optional): model AI function. Defaults to None. + is_subscribed (bool, optional): Is the user subscribed. Defaults to False. + cost (Dict, optional): model price. Defaults to None. + scope (Text, optional): action scope of the connection. Defaults to None. + **additional_info: Any additional Model info to be saved + """ + assert function_type == FunctionType.MCP_CONNECTION, "Connection only supports mcp connection function" + super().__init__( + id=id, + name=name, + description=description, + supplier=supplier, + version=version, + cost=cost, + function=function, + is_subscribed=is_subscribed, + api_key=api_key, + function_type=function_type, + **additional_info, + ) + + def _get_actions(self): + response = Model.run(self, {"action": "LIST_TOOLS", "data": " "}) + if response.status == ResponseStatus.SUCCESS: + return [ + ConnectAction(name=action["name"], description=action["description"], code=action["name"]) + for action in response.data + ] + raise Exception( + f"It was not possible to get the actions for the connection {self.id}. Error {response.error_code}: {response.error_message}" + ) + + def get_action_inputs(self, action: Union[ConnectAction, Text]): + if action.inputs: + return action.inputs + + if isinstance(action, ConnectAction): + action = action.code + + response = Model.run(self, {"action": "LIST_TOOLS", "data": {"actions": [action]}}) + if response.status == ResponseStatus.SUCCESS: + try: + inputs = {inp["code"]: inp for inp in response.data[0]["inputs"]} + action_idx = next((i for i, a in enumerate(self.actions) if a.code == action), None) + if action_idx is not None: + self.actions[action_idx].inputs = inputs + return inputs + except Exception as e: + raise Exception(f"It was not possible to get the inputs for the action {action}. Error {e}") + + raise Exception( + f"It was not possible to get the inputs for the action {action}. Error {response.error_code}: {response.error_message}" + ) diff --git a/tests/functional/agent/agent_functional_test.py b/tests/functional/agent/agent_functional_test.py index 69c9f002..db11222a 100644 --- a/tests/functional/agent/agent_functional_test.py +++ b/tests/functional/agent/agent_functional_test.py @@ -759,3 +759,35 @@ def test_agent_with_action_tool(): assert "helsinki" in response.data.output.lower() assert "SLACK_SENDS_A_MESSAGE_TO_A_SLACK_CHANNEL" in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] connection.delete() + + +def test_agent_with_mcp_tool(): + connector = ModelFactory.get("68549a33ba00e44f357896f1") + # connect + response = connector.connect( + data="https://mcp.zapier.com/api/mcp/s/OTJiMjVlYjEtMGE4YS00OTVjLWIwMGYtZDJjOGVkNTc4NjFkOjI0MTNjNzg5LWZlNGMtNDZmNC05MDhmLWM0MGRlNDU4ZmU1NA==/mcp" + ) + connection_id = response.data["id"] + connection = ModelFactory.get(connection_id) + action_name = "SLACK_SEND_CHANNEL_MESSAGE".lower() + connection.action_scope = [action for action in connection.actions if action.code == action_name] + + agent = AgentFactory.create( + name="Test Agent", + description="This agent is used to send messages to Slack", + instructions="You are a helpful assistant that can send messages to Slack. You MUST use the tool to send the message.", + llm_id="669a63646eb56306647e1091", + tools=[ + connection, + ], + ) + + response = agent.run( + "Send what is the capital of Finland on Slack to channel of #modelserving-alerts-testing. Add the name of the capital in the final answer." + ) + assert response is not None + assert response["status"].lower() == "success" + assert "helsinki" in response.data.output.lower() + assert action_name in [step["tool"] for step in response.data.intermediate_steps[0]["tool_steps"]] + connection.delete() + diff --git a/tests/functional/model/run_connect_model_test.py b/tests/functional/model/run_connect_model_test.py index 3f89b4ad..2b3124b7 100644 --- a/tests/functional/model/run_connect_model_test.py +++ b/tests/functional/model/run_connect_model_test.py @@ -3,6 +3,7 @@ from aixplain.factories import ModelFactory from aixplain.modules.model.integration import Integration, AuthenticationSchema from aixplain.modules.model.connection import ConnectionTool +from aixplain.modules.model.mcp_connection import MCPConnection def test_run_connect_model(): @@ -28,3 +29,29 @@ def test_run_connect_model(): action = action[0] response = connection.run(action, {"text": "This is a test!", "channel": "C084G435LR5"}) assert response.status == ResponseStatus.SUCCESS + + +def test_run_mcp_connect_model(): + # get slack connector + connector = ModelFactory.get("68549a33ba00e44f357896f1") + + assert isinstance(connector, Integration) + assert connector.id == "68549a33ba00e44f357896f1" + assert connector.name == "" + + url = "https://mcp.zapier.com/api/mcp/s/OTJiMjVlYjEtMGE4YS00OTVjLWIwMGYtZDJjOGVkNTc4NjFkOjI0MTNjNzg5LWZlNGMtNDZmNC05MDhmLWM0MGRlNDU4ZmU1NA==" + response = connector.connect(data=url) + assert response.status == ResponseStatus.SUCCESS + assert "id" in response.data + connection_id = response.data["id"] + # get slack connection + connection = ModelFactory.get(connection_id) + assert isinstance(connection, MCPConnection) + assert connection.id == connection_id + assert connection.actions is not None + + action = [action for action in connection.actions if action.code == "SLACK_CHAT_POST_MESSAGE"] + assert len(action) > 0 + action = action[0] + response = connection.run(action, {"text": "This is a test!", "channel": "C084G435LR5"}) + assert response.status == ResponseStatus.SUCCESS diff --git a/tests/unit/model_test.py b/tests/unit/model_test.py index 9abc6fe0..94c7201c 100644 --- a/tests/unit/model_test.py +++ b/tests/unit/model_test.py @@ -34,7 +34,7 @@ from aixplain.modules.model.llm_model import LLM from aixplain.modules.model.index_model import IndexModel from aixplain.modules.model.utility_model import UtilityModel -from aixplain.modules.model.integration import Integration, AuthenticationSchema +from aixplain.modules.model.integration import Integration, AuthenticationSchema, build_connector_params from aixplain.modules.model.connection import ConnectionTool, ConnectAction @@ -675,6 +675,12 @@ def test_model_not_supports_streaming(mocker): "pricing": {"price": 10, "currency": "USD"}, "params": {}, "version": {"id": "1.0"}, + "attributes": [ + { + "name": "auth_schemes", + "code": '["BEARER_TOKEN", "API_KEY", "BASIC"]' + }, + ] }, Integration, ), @@ -736,14 +742,35 @@ def test_create_model_from_response(payload, expected_model_class): @pytest.mark.parametrize( - "authentication_schema, name, token, client_id, client_secret", + "authentication_schema, name, data", [ - (AuthenticationSchema.BEARER, "test-name", "test-token", None, None), - (AuthenticationSchema.OAUTH, "test-name", None, "test-client-id", "test-client-secret"), + (AuthenticationSchema.BEARER_TOKEN, "test-name", {"token": "test-token"}), + (AuthenticationSchema.API_KEY, "test-name", {"api_key": "test-api-key"}), + (AuthenticationSchema.BASIC, "test-name", {"username": "test-user", "password": "test-pass"}), ], ) -def test_connector_connect(mocker, authentication_schema, name, token, client_id, client_secret): +def test_connector_connect(mocker, authentication_schema, name, data): mocker.patch("aixplain.modules.model.integration.Integration.run", return_value={"id": "test-id"}) + additional_info = { + 'attributes': [ + { + 'name': 'auth_schemes', + 'code': '["BEARER_TOKEN", "API_KEY", "BASIC"]' + }, + { + 'name': 'BEARER_TOKEN-inputs', + 'code': '[{"name": "token"}]' + }, + { + 'name': 'API_KEY-inputs', + 'code': '[{"name": "api_key"}]' + }, + { + 'name': 'BASIC-inputs', + 'code': '[{"name": "username"}, {"name": "password"}]' + } + ] + } connector = Integration( id="connector-id", name="connector-name", @@ -752,11 +779,16 @@ def test_connector_connect(mocker, authentication_schema, name, token, client_id supplier="aiXplain", api_key="api_key", version={"id": "1.0"}, + **additional_info ) + args = build_connector_params(name=name) response = connector.connect( - authentication_schema=authentication_schema, name=name, token=token, client_id=client_id, client_secret=client_secret + authentication_schema=authentication_schema, + args=args, + data=data ) - assert response == {"id": "test-id"} + + assert response["id"] == "test-id" def test_connection_init_with_actions(mocker): @@ -863,14 +895,35 @@ def add(aaa: int, bbb: int) -> int: ) def get_mock(id): - if id == "67eff5c0e05614297caeef98": + additional_info = { + 'attributes': [ + { + 'name': 'auth_schemes', + 'code': '["BEARER_TOKEN", "API_KEY", "BASIC"]' + }, + { + 'name': 'BEARER_TOKEN-inputs', + 'code': '[{"name": "token"}]' + }, + { + 'name': 'API_KEY-inputs', + 'code': '[{"name": "api_key"}]' + }, + { + 'name': 'BASIC-inputs', + 'code': '[{"name": "username"}, {"name": "password"}]' + } + ] + } + if id == "686432941223092cb4294d3f": return Integration( - id="67eff5c0e05614297caeef98", + id="686432941223092cb4294d3f", name="test-name", function=Function.UTILITIES, function_type=FunctionType.INTEGRATION, api_key="api_key", version={"id": "1.0"}, + **additional_info ) elif id == "connection-id": return ConnectionTool( @@ -880,10 +933,11 @@ def get_mock(id): function_type=FunctionType.CONNECTION, api_key="api_key", version={"id": "1.0"}, + **additional_info ) mocker.patch("aixplain.factories.tool_factory.ToolFactory.get", side_effect=get_mock) - tool = ToolFactory.create(integration="67eff5c0e05614297caeef98", name="My Connector 1234", token="slack-token") + tool = ToolFactory.create(integration="686432941223092cb4294d3f", name="My Connector 1234", authentication_schema=AuthenticationSchema.BEARER_TOKEN, data={"token": "slack-token"}) assert isinstance(tool, ConnectionTool) assert tool.id == "connection-id" assert tool.name == "test-name"