diff --git a/aixplain/factories/agent_factory/__init__.py b/aixplain/factories/agent_factory/__init__.py index a00fe3bc..b06dfaf2 100644 --- a/aixplain/factories/agent_factory/__init__.py +++ b/aixplain/factories/agent_factory/__init__.py @@ -163,6 +163,14 @@ def create( raise Exception(error_msg) return agent + @classmethod + def create_from_dict(cls, dict: Dict) -> Agent: + """Create an agent from a dictionary.""" + agent = Agent.from_dict(dict) + agent.validate(raise_exception=True) + agent.url = urljoin(config.BACKEND_URL, f"sdk/agents/{agent.id}/run") + return agent + @classmethod def create_task( cls, name: Text, description: Text, expected_output: Text, dependencies: Optional[List[Text]] = None diff --git a/aixplain/factories/team_agent_factory/__init__.py b/aixplain/factories/team_agent_factory/__init__.py index 77f9fad4..16e9794c 100644 --- a/aixplain/factories/team_agent_factory/__init__.py +++ b/aixplain/factories/team_agent_factory/__init__.py @@ -243,6 +243,14 @@ def create( raise Exception(error_msg) return team_agent + @classmethod + def create_from_dict(cls, dict: Dict) -> TeamAgent: + """Create a team agent from a dictionary.""" + team_agent = TeamAgent.from_dict(dict) + team_agent.validate(raise_exception=True) + team_agent.url = urljoin(config.BACKEND_URL, f"sdk/agent-communities/{team_agent.id}/run") + return team_agent + @classmethod def list(cls) -> Dict: """List all agents available in the platform.""" diff --git a/aixplain/modules/agent/__init__.py b/aixplain/modules/agent/__init__.py index e3ca6f89..f0316b00 100644 --- a/aixplain/modules/agent/__init__.py +++ b/aixplain/modules/agent/__init__.py @@ -385,8 +385,84 @@ def to_dict(self) -> Dict: ] if self.llm is not None else [], + "cost": self.cost, + "api_key": self.api_key, } + @classmethod + def from_dict(cls, data: Dict) -> "Agent": + """Create an Agent instance from a dictionary representation. + + Args: + data: Dictionary containing Agent parameters + + Returns: + Agent instance + """ + from aixplain.factories.agent_factory.utils import build_tool + from aixplain.enums import AssetStatus + from aixplain.modules.agent_task import AgentTask + + # Extract tools from assets using proper tool building + tools = [] + if "assets" in data: + for asset_data in data["assets"]: + try: + tool = build_tool(asset_data) + tools.append(tool) + except Exception as e: + # Log warning but continue processing other tools + import logging + + logging.warning(f"Failed to build tool from asset data: {e}") + + # Extract tasks using from_dict method + tasks = [] + if "tasks" in data: + for task_data in data["tasks"]: + tasks.append(AgentTask.from_dict(task_data)) + + # Extract LLM from tools section (main LLM info) + llm = None + if "tools" in data and data["tools"]: + llm_tool = next((tool for tool in data["tools"] if tool.get("type") == "llm"), None) + if llm_tool and llm_tool.get("parameters"): + # Reconstruct LLM from parameters if available + from aixplain.factories.model_factory import ModelFactory + + try: + llm = ModelFactory.get(data.get("llmId", "6646261c6eb563165658bbb1")) + if llm_tool.get("parameters"): + # Apply stored parameters to LLM + llm.set_parameters(llm_tool["parameters"]) + except Exception: + # If LLM loading fails, llm remains None and llm_id will be used + pass + + # Extract status + status = AssetStatus.DRAFT + if "status" in data: + if isinstance(data["status"], str): + status = AssetStatus(data["status"]) + else: + status = data["status"] + + return cls( + id=data["id"], + name=data["name"], + description=data["description"], + instructions=data.get("role"), + tools=tools, + llm_id=data.get("llmId", "6646261c6eb563165658bbb1"), + llm=llm, + api_key=data.get("api_key"), + supplier=data.get("supplier", "aiXplain"), + version=data.get("version"), + cost=data.get("cost"), + status=status, + tasks=tasks, + ) + def delete(self) -> None: """Delete Agent service""" try: diff --git a/aixplain/modules/agent/agent_task.py b/aixplain/modules/agent/agent_task.py index 8d6acd2b..593a0d0f 100644 --- a/aixplain/modules/agent/agent_task.py +++ b/aixplain/modules/agent/agent_task.py @@ -27,3 +27,20 @@ def to_dict(self): if isinstance(dependency, AgentTask): agent_task_dict["dependencies"][i] = dependency.name return agent_task_dict + + @classmethod + def from_dict(cls, data: dict) -> "AgentTask": + """Create an AgentTask instance from a dictionary representation. + + Args: + data: Dictionary containing AgentTask parameters + + Returns: + AgentTask instance + """ + return cls( + name=data["name"], + description=data["description"], + expected_output=data["expectedOutput"], + dependencies=data.get("dependencies", None), + ) diff --git a/aixplain/modules/team_agent/__init__.py b/aixplain/modules/team_agent/__init__.py index fa1aca9e..078285d2 100644 --- a/aixplain/modules/team_agent/__init__.py +++ b/aixplain/modules/team_agent/__init__.py @@ -335,6 +335,29 @@ def delete(self) -> None: logging.error(message) raise Exception(f"{message}") + def _serialize_agent(self, agent, idx: int) -> Dict: + """Serialize an agent for the to_dict method.""" + base_dict = {"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"} + + # Try to get additional data from agent's to_dict method + try: + if hasattr(agent, "to_dict") and callable(getattr(agent, "to_dict")): + agent_dict = agent.to_dict() + # Ensure it's actually a dictionary and not a Mock or other object + if isinstance(agent_dict, dict) and hasattr(agent_dict, "items"): + try: + # Add all fields except 'id' to avoid duplication with 'assetId' + additional_data = {k: v for k, v in agent_dict.items() if k not in ["id"]} + base_dict.update(additional_data) + except (TypeError, AttributeError): + # If items() doesn't work or iteration fails, skip the additional data + pass + except Exception: + # If anything goes wrong, just use the base dictionary + pass + + return base_dict + def to_dict(self) -> Dict: if self.use_mentalist: planner_id = self.mentalist_llm.id if self.mentalist_llm else self.llm_id @@ -343,9 +366,7 @@ def to_dict(self) -> Dict: return { "id": self.id, "name": self.name, - "agents": [ - {"assetId": agent.id, "number": idx, "type": "AGENT", "label": "AGENT"} for idx, agent in enumerate(self.agents) - ], + "agents": [self._serialize_agent(agent, idx) for idx, agent in enumerate(self.agents)], "links": [], "description": self.description, "llmId": self.llm.id if self.llm else self.llm_id, @@ -359,6 +380,107 @@ def to_dict(self) -> Dict: "role": self.instructions, } + @classmethod + def from_dict(cls, data: Dict) -> "TeamAgent": + """Create a TeamAgent instance from a dictionary representation. + + Args: + data: Dictionary containing TeamAgent parameters + + Returns: + TeamAgent instance + """ + from aixplain.factories.agent_factory import AgentFactory + from aixplain.factories.model_factory import ModelFactory + from aixplain.enums import AssetStatus + from aixplain.modules.team_agent import Inspector, InspectorTarget + + # Extract agents from agents list using proper agent loading + agents = [] + if "agents" in data: + for agent_data in data["agents"]: + if "assetId" in agent_data: + try: + # Load agent using AgentFactory + agent = AgentFactory.get(agent_data["assetId"]) + agents.append(agent) + except Exception as e: + # Log warning but continue processing other agents + import logging + + logging.warning(f"Failed to load agent {agent_data['assetId']}: {e}") + + # Extract inspectors using proper model validation + inspectors = [] + if "inspectors" in data: + for inspector_data in data["inspectors"]: + try: + if hasattr(Inspector, "model_validate"): + inspectors.append(Inspector.model_validate(inspector_data)) + else: + inspectors.append(Inspector(**inspector_data)) + except Exception as e: + import logging + + logging.warning(f"Failed to create inspector from data: {e}") + + # Extract inspector targets + inspector_targets = [InspectorTarget.STEPS] # default + if "inspectorTargets" in data: + inspector_targets = [InspectorTarget(target) for target in data["inspectorTargets"]] + + # Extract status + status = AssetStatus.DRAFT + if "status" in data: + if isinstance(data["status"], str): + status = AssetStatus(data["status"]) + else: + status = data["status"] + + # Extract LLM instances using proper model loading + llm = None + supervisor_llm = None + mentalist_llm = None + + try: + if "llmId" in data: + llm = ModelFactory.get(data["llmId"]) + except Exception: + pass # llm remains None, will use llm_id + + try: + if "supervisorId" in data and data["supervisorId"] != data.get("llmId"): + supervisor_llm = ModelFactory.get(data["supervisorId"]) + except Exception: + pass # supervisor_llm remains None + + try: + if "plannerId" in data and data["plannerId"]: + mentalist_llm = ModelFactory.get(data["plannerId"]) + except Exception: + pass # mentalist_llm remains None + + # Determine if mentalist is used + use_mentalist = data.get("plannerId") is not None + + return cls( + id=data["id"], + name=data["name"], + agents=agents, + description=data.get("description", ""), + llm_id=data.get("llmId", "6646261c6eb563165658bbb1"), + llm=llm, + supervisor_llm=supervisor_llm, + mentalist_llm=mentalist_llm, + supplier=data.get("supplier", "aiXplain"), + version=data.get("version"), + use_mentalist=use_mentalist, + status=status, + instructions=data.get("role"), + inspectors=inspectors, + inspector_targets=inspector_targets, + ) + def _validate(self) -> None: from aixplain.utils.llm_utils import get_llm_instance @@ -429,4 +551,3 @@ def update(self) -> None: def __repr__(self): return f"TeamAgent: {self.name} (id={self.id})" - diff --git a/tests/unit/agent/agent_test.py b/tests/unit/agent/agent_test.py index a93c35a8..8546b998 100644 --- a/tests/unit/agent/agent_test.py +++ b/tests/unit/agent/agent_test.py @@ -1184,3 +1184,184 @@ def test_create_agent_with_duplicate_tool_names(mocker): assert "Agent Creation Error - Duplicate tool names found: Test Model. Make sure all tool names are unique." in str( exc_info.value ) + + +def test_agent_task_serialization(): + """Test AgentTask to_dict/from_dict round-trip serialization.""" + from aixplain.modules.agent.agent_task import AgentTask + + # Create test task + task = AgentTask( + name="Test Task", + description="A test task for validation", + expected_output="Expected output description", + dependencies=["task1", "task2"], + ) + + # Test to_dict + task_dict = task.to_dict() + expected_keys = {"name", "description", "expectedOutput", "dependencies"} + assert set(task_dict.keys()) == expected_keys + assert task_dict["name"] == "Test Task" + assert task_dict["description"] == "A test task for validation" + assert task_dict["expectedOutput"] == "Expected output description" + assert task_dict["dependencies"] == ["task1", "task2"] + + # Test from_dict + reconstructed_task = AgentTask.from_dict(task_dict) + + # Verify round-trip + assert task.name == reconstructed_task.name + assert task.description == reconstructed_task.description + assert task.expected_output == reconstructed_task.expected_output + assert task.dependencies == reconstructed_task.dependencies + + +def test_agent_task_serialization_with_task_dependencies(): + """Test AgentTask serialization when dependencies are AgentTask objects.""" + from aixplain.modules.agent.agent_task import AgentTask + + # Create dependency tasks + dep_task1 = AgentTask(name="Dependency Task 1", description="First dependency", expected_output="Dep output 1") + dep_task2 = AgentTask(name="Dependency Task 2", description="Second dependency", expected_output="Dep output 2") + + # Create main task with AgentTask dependencies + main_task = AgentTask( + name="Main Task", + description="Main task with AgentTask dependencies", + expected_output="Main output", + dependencies=[dep_task1, dep_task2, "string_dependency"], + ) + + # Test to_dict - should convert AgentTask dependencies to names + task_dict = main_task.to_dict() + assert task_dict["dependencies"] == ["Dependency Task 1", "Dependency Task 2", "string_dependency"] + + # Test from_dict - dependencies will be strings + reconstructed_task = AgentTask.from_dict(task_dict) + assert reconstructed_task.dependencies == ["Dependency Task 1", "Dependency Task 2", "string_dependency"] + + +def test_agent_serialization_completeness(): + """Test that Agent to_dict includes all necessary fields.""" + from aixplain.modules.agent.agent_task import AgentTask + + # Create test tasks + task1 = AgentTask(name="Task 1", description="First task", expected_output="Output 1") + task2 = AgentTask(name="Task 2", description="Second task", expected_output="Output 2", dependencies=["Task 1"]) + + # Create test agent with comprehensive data + agent = Agent( + id="test-agent-123", + name="Test Agent", + description="A test agent for validation", + instructions="You are a helpful test agent", + tools=[], # Empty for simplicity + llm_id="6646261c6eb563165658bbb1", + api_key="test-api-key", + supplier="aixplain", + version="1.0.0", + cost={"input": 0.01, "output": 0.02}, + status=AssetStatus.DRAFT, + tasks=[task1, task2], + ) + + # Test to_dict includes all expected fields + agent_dict = agent.to_dict() + + required_fields = { + "id", + "name", + "description", + "role", + "assets", + "supplier", + "version", + "llmId", + "status", + "tasks", + "tools", + "cost", + "api_key", + } + assert set(agent_dict.keys()) == required_fields + + # Verify field values + assert agent_dict["id"] == "test-agent-123" + assert agent_dict["name"] == "Test Agent" + assert agent_dict["description"] == "A test agent for validation" + assert agent_dict["role"] == "You are a helpful test agent" + assert agent_dict["llmId"] == "6646261c6eb563165658bbb1" + assert agent_dict["api_key"] == "test-api-key" + assert agent_dict["supplier"] == "aixplain" + assert agent_dict["version"] == "1.0.0" + assert agent_dict["cost"] == {"input": 0.01, "output": 0.02} + assert agent_dict["status"] == "draft" + assert isinstance(agent_dict["assets"], list) + assert isinstance(agent_dict["tasks"], list) + assert len(agent_dict["tasks"]) == 2 + + # Verify task serialization + task_dict = agent_dict["tasks"][0] + assert task_dict["name"] == "Task 1" + assert task_dict["description"] == "First task" + assert task_dict["expectedOutput"] == "Output 1" + + +def test_agent_serialization_with_llm(): + """Test Agent to_dict when LLM instance is provided.""" + from unittest.mock import Mock + + # Mock LLM with parameters + mock_llm = Mock() + mock_llm.id = "custom-llm-id" + mock_parameters = Mock() + mock_parameters.to_list.return_value = [{"name": "temperature", "value": 0.7}] + mock_llm.get_parameters.return_value = mock_parameters + + agent = Agent(id="test-agent", name="Test Agent", description="Test description", llm=mock_llm, llm_id="fallback-llm-id") + + agent_dict = agent.to_dict() + + # Should use LLM instance ID instead of llm_id + assert agent_dict["llmId"] == "custom-llm-id" + + # Should include LLM parameters in tools section + assert len(agent_dict["tools"]) == 1 + llm_tool = agent_dict["tools"][0] + assert llm_tool["type"] == "llm" + assert llm_tool["description"] == "main" + assert llm_tool["parameters"] == [{"name": "temperature", "value": 0.7}] + + +def test_agent_serialization_role_fallback(): + """Test Agent to_dict role field fallback behavior.""" + # Test with instructions provided + agent_with_instructions = Agent( + id="test1", name="Test Agent 1", description="Test description", instructions="Custom instructions" + ) + + dict1 = agent_with_instructions.to_dict() + assert dict1["role"] == "Custom instructions" + + # Test without instructions (should fall back to description) + agent_without_instructions = Agent(id="test2", name="Test Agent 2", description="Test description") + + dict2 = agent_without_instructions.to_dict() + assert dict2["role"] == "Test description" + + +@pytest.mark.parametrize( + "status_input,expected_output", + [ + (AssetStatus.DRAFT, "draft"), + (AssetStatus.ONBOARDED, "onboarded"), + (AssetStatus.COMPLETED, "completed"), + ], +) +def test_agent_serialization_status_enum(status_input, expected_output): + """Test Agent to_dict properly serializes AssetStatus enum.""" + agent = Agent(id="test-agent", name="Test Agent", description="Test description", status=status_input) + + agent_dict = agent.to_dict() + assert agent_dict["status"] == expected_output diff --git a/tests/unit/team_agent/team_agent_test.py b/tests/unit/team_agent/team_agent_test.py index 5a154d0b..05f02fc3 100644 --- a/tests/unit/team_agent/team_agent_test.py +++ b/tests/unit/team_agent/team_agent_test.py @@ -388,3 +388,176 @@ def test_deploy_team_agent(): # Verify that status was updated and update was called assert team_agent.status == AssetStatus.ONBOARDED team_agent.update.assert_called_once() + + +def test_team_agent_serialization_completeness(): + """Test that TeamAgent to_dict includes all necessary fields.""" + from unittest.mock import Mock + + # Create mock agents + mock_agent1 = Mock() + mock_agent1.id = "agent-1" + mock_agent1.name = "Agent 1" + + mock_agent2 = Mock() + mock_agent2.id = "agent-2" + mock_agent2.name = "Agent 2" + + # Create mock inspectors + mock_inspector = Mock() + mock_inspector.model_dump.return_value = {"type": "test_inspector", "config": {"threshold": 0.8}} + + # Create test team agent with comprehensive data + team_agent = TeamAgent( + id="test-team-123", + name="Test Team", + agents=[mock_agent1, mock_agent2], + description="A test team agent", + llm_id="6646261c6eb563165658bbb1", + supervisor_llm=None, + mentalist_llm=None, + supplier="aixplain", + version="1.0.0", + use_mentalist=False, + status=AssetStatus.DRAFT, + instructions="You are a helpful team agent", + inspectors=[mock_inspector], + inspector_targets=[InspectorTarget.STEPS], + ) + + # Test to_dict includes all expected fields + team_dict = team_agent.to_dict() + + required_fields = { + "id", + "name", + "agents", + "links", + "description", + "llmId", + "supervisorId", + "plannerId", + "inspectors", + "inspectorTargets", + "supplier", + "version", + "status", + "role", + } + assert set(team_dict.keys()) == required_fields + + # Verify field values + assert team_dict["id"] == "test-team-123" + assert team_dict["name"] == "Test Team" + assert team_dict["description"] == "A test team agent" + assert team_dict["role"] == "You are a helpful team agent" + assert team_dict["llmId"] == "6646261c6eb563165658bbb1" + assert team_dict["supplier"] == "aixplain" + assert team_dict["version"] == "1.0.0" + assert team_dict["status"] == "draft" + assert team_dict["links"] == [] + assert team_dict["plannerId"] is None # use_mentalist=False + + # Verify agents serialization + assert isinstance(team_dict["agents"], list) + assert len(team_dict["agents"]) == 2 + agent_dict = team_dict["agents"][0] + assert agent_dict["assetId"] == "agent-1" + assert agent_dict["number"] == 0 + assert agent_dict["type"] == "AGENT" + assert agent_dict["label"] == "AGENT" + + # Verify inspectors serialization + assert isinstance(team_dict["inspectors"], list) + assert len(team_dict["inspectors"]) == 1 + assert team_dict["inspectors"][0] == {"type": "test_inspector", "config": {"threshold": 0.8}} + + # Verify inspector targets + assert team_dict["inspectorTargets"] == ["steps"] + + +def test_team_agent_serialization_with_llms(): + """Test TeamAgent to_dict when LLM instances are provided.""" + from unittest.mock import Mock + + # Mock different LLMs + mock_llm = Mock() + mock_llm.id = "main-llm-id" + + mock_supervisor = Mock() + mock_supervisor.id = "supervisor-llm-id" + + mock_mentalist = Mock() + mock_mentalist.id = "mentalist-llm-id" + + team_agent = TeamAgent( + id="test-team", + name="Test Team", + agents=[], + description="Test team with LLMs", + llm_id="fallback-llm-id", + llm=mock_llm, + supervisor_llm=mock_supervisor, + mentalist_llm=mock_mentalist, + use_mentalist=True, + ) + + team_dict = team_agent.to_dict() + + # Should use LLM instance IDs + assert team_dict["llmId"] == "main-llm-id" + assert team_dict["supervisorId"] == "supervisor-llm-id" + assert team_dict["plannerId"] == "mentalist-llm-id" + + +def test_team_agent_serialization_mentalist_logic(): + """Test TeamAgent to_dict plannerId logic based on use_mentalist and mentalist_llm.""" + + # Case 1: use_mentalist=True but no mentalist_llm + team_agent1 = TeamAgent(id="team1", name="Team 1", agents=[], use_mentalist=True, llm_id="main-llm") + + dict1 = team_agent1.to_dict() + assert dict1["plannerId"] == "main-llm" # Falls back to main LLM + + # Case 2: use_mentalist=False + team_agent2 = TeamAgent(id="team2", name="Team 2", agents=[], use_mentalist=False, llm_id="main-llm") + + dict2 = team_agent2.to_dict() + assert dict2["plannerId"] is None + + +@pytest.mark.parametrize( + "status_input,expected_output", + [ + (AssetStatus.DRAFT, "draft"), + (AssetStatus.ONBOARDED, "onboarded"), + (AssetStatus.COMPLETED, "completed"), + ], +) +def test_team_agent_serialization_status_enum(status_input, expected_output): + """Test TeamAgent to_dict properly serializes AssetStatus enum.""" + team_agent = TeamAgent(id="test-team", name="Test Team", agents=[], description="Test description", status=status_input) + + team_dict = team_agent.to_dict() + assert team_dict["status"] == expected_output + + +def test_team_agent_serialization_supervisor_fallback(): + """Test TeamAgent to_dict supervisorId fallback behavior.""" + + # Case 1: No supervisor_llm provided + team_agent1 = TeamAgent(id="team1", name="Team 1", agents=[], llm_id="main-llm-id") + + dict1 = team_agent1.to_dict() + assert dict1["supervisorId"] == "main-llm-id" # Falls back to main LLM + + # Case 2: supervisor_llm provided + from unittest.mock import Mock + + mock_supervisor = Mock() + mock_supervisor.id = "supervisor-llm-id" + + team_agent2 = TeamAgent(id="team2", name="Team 2", agents=[], llm_id="main-llm-id", supervisor_llm=mock_supervisor) + + dict2 = team_agent2.to_dict() + assert dict2["supervisorId"] == "supervisor-llm-id"