Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
418 changes: 418 additions & 0 deletions examples/07_agents_bedrock.ipynb

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions libs/core/llmstudio_core/agents/bedrock/data_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import uuid
from typing import Optional

from llmstudio_core.agents.data_models import (
AgentBase,
CreateAgentRequest,
RunAgentRequest,
RunBase,
)
from pydantic import Field


class BedrockAgent(AgentBase):
Expand All @@ -25,5 +29,4 @@ class BedrockCreateAgentRequest(CreateAgentRequest):


class BedrockRunAgentRequest(RunAgentRequest):
session_id: str
agent_alias_id: str
session_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
96 changes: 57 additions & 39 deletions libs/core/llmstudio_core/agents/bedrock/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Message,
ResultBase,
TextContent,
TextObject,
)
from llmstudio_core.agents.manager import AgentManager, agent_manager
from llmstudio_core.exceptions import AgentError
Expand Down Expand Up @@ -59,17 +60,19 @@ def _validate_run_request(self, request):
return BedrockRunAgentRequest(**request)

def _validate_result_request(self, request):
return RetrieveResultRequest(**request)
if isinstance(request, BedrockRun):
return request
return BedrockRun(**request)

def create_agent(self, **kwargs) -> BedrockAgent:
def create_agent(self, params: dict = None) -> BedrockAgent:
"""
This method validates the input parameters, creates a new agent using the client,
waits for the agent to reach the 'NOT_PREPARED' status, adds tools to the agent,
prepares the agent for use, creates an alias for the agent, and waits for the alias
to be prepared.

Args:
**kwargs: Agent creation parameters.
params: Agent creation parameters.

Returns:
BedrockAgent: An instance of the created BedrockAgent.
Expand All @@ -80,11 +83,7 @@ def create_agent(self, **kwargs) -> BedrockAgent:
"""

try:
agent_request = self._validate_create_request(
dict(
**kwargs,
)
)
agent_request = self._validate_create_request(params)

except ValidationError as e:
raise AgentError(str(e))
Expand Down Expand Up @@ -153,7 +152,7 @@ def create_agent(self, **kwargs) -> BedrockAgent:
agentAliasStatus = response["agentAlias"]["agentAliasStatus"]

return BedrockAgent(
id=agentId,
agent_id=agentId,
created_at=int(bedrock_create["agent"]["createdAt"].timestamp()),
name=bedrock_create["agent"]["agentName"],
description=bedrock_create.get("agent", {}).get("description", None),
Expand All @@ -166,7 +165,7 @@ def create_agent(self, **kwargs) -> BedrockAgent:
agent_alias_id=agentAliasId,
)

def run_agent(self, **kwargs) -> BedrockRun:
def run_agent(self, params: dict = None) -> BedrockRun:
"""
Runs the agent with the provided keyword arguments.

Expand All @@ -181,17 +180,34 @@ def run_agent(self, **kwargs) -> BedrockRun:
"""

try:
run_request = self._validate_run_request(
dict(
**kwargs,
)
)
run_request = self._validate_run_request(params)
except ValidationError as e:
raise AgentError(str(e))

sessionState = {"files": []}
sessionState = {"files": [], "conversationHistory": {"messages": []}}

if isinstance(run_request.messages, Message):
last_message = run_request.messages
elif isinstance(run_request.messages, list) and run_request.messages:
last_message = run_request.messages.pop()

for message in run_request.messages:
bedrock_message = {"role": message.role, "content": []}

# Extract text content
if isinstance(message.content, str):
bedrock_message["content"].append({"text": message.content})

elif isinstance(message.content, list):
for item in message.content:
if isinstance(item, TextContent):
bedrock_message["content"].append({"text": item.text.value})

for attachment in run_request.message.attachments:
sessionState["conversationHistory"]["messages"].append(bedrock_message)
else:
raise AgentError("No valid messages found in the run request")

for attachment in last_message.attachments:
if any(tool.type == "code_interpreter" for tool in attachment.tools):
sessionState["files"].append(
{
Expand All @@ -207,33 +223,33 @@ def run_agent(self, **kwargs) -> BedrockRun:
}
)

if isinstance(run_request.message.content, str):
input_text = run_request.message.content # Use it directly if it's a string
elif isinstance(run_request.message.content, list):
if isinstance(last_message.content, str):
input_text = last_message.content # Use it directly if it's a string
elif isinstance(last_message.content, list):
input_text = " ".join(
item.text
for item in run_request.message.content
for item in last_message.content
if isinstance(item, TextContent)
)
else:
input_text = "" # Default to an empty string if content is not valid

invoke_request = self._runtime_client.invoke_agent(
agentId=run_request.agent_id,
agentAliasId=run_request.agent_alias_id,
agentId=run_request.agent.agent_id,
agentAliasId=run_request.agent.agent_alias_id,
sessionId=run_request.session_id,
inputText=input_text,
sessionState=sessionState,
)

return BedrockRun(
agent_id=run_request.agent_id,
agent_id=run_request.agent.agent_id,
status="completed",
session_id=run_request.session_id,
response=invoke_request,
)

def retrieve_result(self, **kwargs) -> ResultBase:
def retrieve_result(self, run: BedrockRun) -> ResultBase:
"""
Retrieve the result based on the provided keyword arguments.
This method validates the result request and processes the event stream to
Expand All @@ -247,23 +263,23 @@ def retrieve_result(self, **kwargs) -> ResultBase:
"""

try:
result_request = self._validate_result_request(
dict(
**kwargs,
)
)
run = self._validate_result_request(run)

except ValidationError as e:
raise AgentError(str(e))

content = []
attachments = []
event_stream = result_request.run.response.get("completion")
event_stream = run.response.get("completion")
for event in event_stream:
if "chunk" in event:
chunk = event["chunk"]
if "bytes" in chunk:
content.append(TextContent(text=chunk["bytes"].decode("utf-8")))
content.append(
TextContent(
text=TextObject(value=chunk["bytes"].decode("utf-8"))
)
)

if "files" in event:
files = event["files"]["files"]
Expand All @@ -287,11 +303,13 @@ def retrieve_result(self, **kwargs) -> ResultBase:
)
)

message = Message(
thread_id=result_request.run.session_id,
role="assistant",
content=content,
attachments=attachments,
)
messages = [
Message(
thread_id=run.session_id,
role="assistant",
content=content,
attachments=attachments,
)
]

return ResultBase(message=message)
return ResultBase(messages=messages)
4 changes: 2 additions & 2 deletions libs/core/llmstudio_core/agents/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,5 @@ class CreateAgentRequest(BaseModel):


class RunAgentRequest(BaseModel):
agent_id: str
message: Union[Message, List[Message]]
agent: AgentBase
messages: Union[Message, List[Message]]