Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/bedrock_agentcore/memory/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from typing import Any, Dict

from .DictWrapper import DictWrapper

from .filters import (
StringValue,
MetadataValue,
MetadataKey,
LeftExpression,
OperatorType,
RightExpression,
EventMetadataFilter,
)

class ActorSummary(DictWrapper):
"""A class representing an actor summary."""
Expand Down Expand Up @@ -75,3 +83,20 @@ def __init__(self, session_summary: Dict[str, Any]):
session_summary: Dictionary containing session summary data.
"""
super().__init__(session_summary)

__all__ = [
"DictWrapper",
"ActorSummary",
"Branch",
"Event",
"EventMessage",
"MemoryRecord",
"SessionSummary",
"StringValue",
"MetadataValue",
"MetadataKey",
"LeftExpression",
"OperatorType",
"RightExpression",
"EventMetadataFilter",
]
118 changes: 118 additions & 0 deletions src/bedrock_agentcore/memory/models/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from enum import Enum
from typing import Optional, TypedDict, Union, NotRequired

class StringValue(TypedDict):
"""Value associated with the `eventMetadata` key."""
stringValue: str

@staticmethod
def build(value: str) -> 'StringValue':
return {
"stringValue": value
}

MetadataValue = Union[StringValue]
"""
Union type representing metadata values.

Variants:
- StringValue: {"stringValue": str} - String metadata value
"""

MetadataKey = Union[str]
"""
Union type representing metadata key.
"""

class LeftExpression(TypedDict):
"""
Left operand of the event metadata filter expression.
"""
metadataKey: MetadataKey

@staticmethod
def build(key: str) -> 'LeftExpression':
"""Builds the `metadataKey` for `LeftExpression`"""
return {
"metadataKey": key
}

class OperatorType(Enum):
"""
Operator applied to the event metadata filter expression.

Currently supports:
- `EQUALS_TO`
- `EXISTS`
- `NOT_EXISTS`
"""
EQUALS_TO = "EQUALS_TO"
EXISTS = "EXISTS"
NOT_EXISTS = "NOT_EXISTS"

class RightExpression(TypedDict):
"""
Right operand of the event metadata filter expression.

Variants:
- StringValue: {"metadataValue": {"stringValue": str}}
"""
metadataValue: MetadataValue

@staticmethod
def build(value: str) -> 'RightExpression':
"""Builds the `RightExpression` for `stringValue` type"""
return {"metadataValue": StringValue.build(value)}

class EventMetadataFilter(TypedDict):
"""
Filter expression for retrieving events based on metadata associated with an event.

Args:
left: `LeftExpression` of the event metadata filter expression.
operator: `OperatorType` applied to the event metadata filter expression.
right: Optional `RightExpression` of the event metadata filter expression.
"""
left: LeftExpression
operator: OperatorType
right: NotRequired[RightExpression]

def build_expression(left_operand: LeftExpression, operator: OperatorType, right_operand: Optional[RightExpression] = None) -> 'EventMetadataFilter':
"""
This method builds the required event metadata filter expression into the `EventMetadataFilterExpression` type when querying listEvents.

Args:
left_operand: Left operand of the event metadata filter expression
operator: Operator applied to the event metadata filter expression
right_operand: Optional right_operand of the event metadata filter expression.

Example:
```
left_operand = LeftExpression.build_key(key='location')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be possible to have the syntax like

LeftExpression(key='location')

and then we infer the type of from inspecting the type? Like we use the stringValue by checking that the key argument` is a string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As this is a TypedDict - we cannot - as it's supposed to be a type hint for dicts with specific structures - it doesn't provide runtime type conversion or inference.

For inferring the type - if needed, we can create a class method - which then invokes the respective methods.

Example:

class RightExpression(TypedDict):
    metadataValue: MetadataValue

    @classmethod
    def create(cls, value: Union[str, float, bool]) -> 'RightExpression':

        if isinstance(value, str):
            return cls(metadataValue=StringValue.create(value))
        elif isinstance(value, (int, float)):
            return cls(metadataValue=NumberValue.create(float(value)))
        else:
            raise ValueError(f"Unsupported value type: {type(value)}")

(Where we'd have other new TypedDict classes when we introduce new types (e.g: NumberValue))

If at any point, we'd want to support other functionalities for each of these types or for LeftExpression or RightExpression as an overall feature - the customer would still have to invoke the respective methods manually without any inference - where those methods would still end up being @staticmethod.

I think having it as above is more intuitive and allows the customer to understand what key type needs to be built.

operator = OperatorType.EQUALS_TO
right_operand = RightExpression.build_string_value(value='NYC')
```

#### Response Object:
```
{
'left': {
'metadataKey': 'location'
},
'operator': 'EQUALS_TO',
'right': {
'metadataValue': {
'stringValue': 'NYC'
}
}
}
```
"""
filter = {
'left': left_operand,
'operator': operator.value
}

if right_operand:
filter['right'] = right_operand
return filter
82 changes: 80 additions & 2 deletions src/bedrock_agentcore/memory/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
EventMessage,
MemoryRecord,
SessionSummary,
MetadataValue,
EventMetadataFilter
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -246,6 +248,7 @@ def process_turn_with_llm(
user_input: str,
llm_callback: Callable[[str, List[Dict[str, Any]]], str],
retrieval_config: Optional[Dict[str, RetrievalConfig]],
metadata: Optional[Dict[str, MetadataValue]] = None,
event_timestamp: Optional[datetime] = None,
) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]:
r"""Complete conversation turn with LLM callback integration.
Expand All @@ -263,6 +266,7 @@ def process_turn_with_llm(
retrieval_config: Optional dictionary mapping namespaces to RetrievalConfig objects.
Each namespace can contain template variables like {actorId}, {sessionId},
{memoryStrategyId} that will be resolved at runtime.
metadata: Optional custom key-value metadata to attach to an event.
event_timestamp: Optional timestamp for the event

Returns:
Expand Down Expand Up @@ -340,6 +344,7 @@ def my_llm(user_input: str, memories: List[Dict]) -> str:
ConversationalMessage(user_input, MessageRole.USER),
ConversationalMessage(agent_response, MessageRole.ASSISTANT),
],
metadata=metadata,
event_timestamp=event_timestamp,
)

Expand All @@ -352,6 +357,7 @@ def add_turns(
session_id: str,
messages: List[Union[ConversationalMessage, BlobMessage]],
branch: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, MetadataValue]] = None,
event_timestamp: Optional[datetime] = None,
) -> Event:
"""Adds conversational turns or blob objects to short-term memory.
Expand All @@ -365,21 +371,31 @@ def add_turns(
- ConversationalMessage objects for conversational messages
- BlobMessage objects for blob data
branch: Optional branch info
metadata: Optional custom key-value metadata to attach to an event.
event_timestamp: Optional timestamp for the event

Returns:
Created event

Example:
```
manager.add_turns(
actor_id="user-123",
session_id="session-456",
messages=[
ConversationalMessage("Hello", USER),
BlobMessage({"file_data": "base64_content"}),
ConversationalMessage("How can I help?", ASSISTANT)
],
metadata=[
{
'location': {
'stringValue': 'NYC'
}
}
]
)
```
"""
logger.info(" -> Storing %d messages in short-term memory...", len(messages))

Expand Down Expand Up @@ -412,6 +428,10 @@ def add_turns(

if branch:
params["branch"] = branch

if metadata:
params["metadata"] = metadata

try:
response = self._data_plane_client.create_event(**params)
logger.info(" ✅ Turn stored successfully with Event ID: %s", response.get("eventId"))
Expand All @@ -427,6 +447,7 @@ def fork_conversation(
root_event_id: str,
branch_name: str,
messages: List[Union[ConversationalMessage, BlobMessage]],
metadata: Optional[Dict[str, MetadataValue]] = None,
event_timestamp: Optional[datetime] = None,
) -> Dict[str, Any]:
"""Fork a conversation from a specific event to create a new branch."""
Expand All @@ -439,6 +460,7 @@ def fork_conversation(
messages=messages,
event_timestamp=event_timestamp,
branch=branch,
metadata=metadata,
)

logger.info("Created branch '%s' from event %s", branch_name, root_event_id)
Expand All @@ -454,6 +476,7 @@ def list_events(
session_id: str,
branch_name: Optional[str] = None,
include_parent_branches: bool = False,
eventMetadata: Optional[List[EventMetadataFilter]] = None,
max_results: int = 100,
include_payload: bool = True,
) -> List[Event]:
Expand Down Expand Up @@ -482,6 +505,49 @@ def list_events(

# Get events from a specific branch
branch_events = client.list_events(actor_id, session_id, branch_name="test-branch")

#### Get events with event metadata filter
```
filtered_events_with_metadata = client.list_events(
actor_id=actor_id,
session_id=session_id,
eventMetadata=[
{
'left': {
'metadataKey': 'location'
},
'operator': 'EQUALS_TO',
'right': {
'metadataValue': {
'stringValue': 'NYC'
}
}
}
]
)
```

#### Get events with event metadata filter + specific branch filter
```
branch_with_metadata_filtered_events = client.list_events(
actor_id=actor_id,
session_id=session_id,
branch_name="test-branch",
eventMetadata=[
{
'left': {
'metadataKey': 'location'
},
'operator': 'EQUALS_TO',
'right': {
'metadataValue': {
'stringValue': 'NYC'
}
}
}
]
)
```
"""
try:
all_events: List[Event] = []
Expand Down Expand Up @@ -509,6 +575,12 @@ def list_events(
"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}
}

# Add eventMetadata filter if specified
if eventMetadata:
params["filter"] = {
"eventMetadata": eventMetadata
}

response = self._data_plane_client.list_events(**params)

events = response.get("events", [])
Expand Down Expand Up @@ -888,28 +960,31 @@ def add_turns(
self,
messages: List[Union[ConversationalMessage, BlobMessage]],
branch: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, MetadataValue]] = None,
event_timestamp: Optional[datetime] = None,
) -> Event:
"""Delegates to manager.add_turns."""
return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, event_timestamp)
return self._manager.add_turns(self._actor_id, self._session_id, messages, branch, metadata, event_timestamp)

def fork_conversation(
self,
messages: List[Union[ConversationalMessage, BlobMessage]],
root_event_id: str,
branch_name: str,
metadata: Optional[Dict[str, MetadataValue]] = None,
event_timestamp: Optional[datetime] = None,
) -> Event:
"""Delegates to manager.fork_conversation."""
return self._manager.fork_conversation(
self._actor_id, self._session_id, root_event_id, branch_name, messages, event_timestamp
self._actor_id, self._session_id, root_event_id, branch_name, messages, metadata, event_timestamp
)

def process_turn_with_llm(
self,
user_input: str,
llm_callback: Callable[[str, List[Dict[str, Any]]], str],
retrieval_config: Optional[Dict[str, RetrievalConfig]],
metadata: Optional[Dict[str, MetadataValue]] = None,
event_timestamp: Optional[datetime] = None,
) -> Tuple[List[Dict[str, Any]], str, Dict[str, Any]]:
"""Delegates to manager.process_turn_with_llm."""
Expand All @@ -919,6 +994,7 @@ def process_turn_with_llm(
user_input,
llm_callback,
retrieval_config,
metadata,
event_timestamp,
)

Expand Down Expand Up @@ -975,6 +1051,7 @@ def list_events(
self,
branch_name: Optional[str] = None,
include_parent_branches: bool = False,
eventMetadata: Optional[List[EventMetadataFilter]] = None,
max_results: int = 100,
include_payload: bool = True,
) -> List[Event]:
Expand All @@ -984,6 +1061,7 @@ def list_events(
session_id=self._session_id,
branch_name=branch_name,
include_parent_branches=include_parent_branches,
eventMetadata=eventMetadata,
include_payload=include_payload,
max_results=max_results,
)
Expand Down
Loading