Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import json
import logging
from datetime import datetime, timezone
import threading
from datetime import datetime, timezone, timedelta
from typing import TYPE_CHECKING, Any, Optional

import boto3
Expand Down Expand Up @@ -68,6 +69,9 @@ def __init__(
self.memory_client = MemoryClient(region_name=region_name)
session = boto_session or boto3.Session(region_name=region_name)
self.has_existing_agent = False
self._last_timestamp = None
self._sequence_counter = 0
self._timestamp_lock = threading.Lock()

# Override the clients if custom boto session or config is provided
# Add strands-agents to the request user agent
Expand All @@ -90,6 +94,32 @@ def __init__(
)
super().__init__(session_id=self.config.session_id, session_repository=self)

def _get_monotonic_timestamp(self) -> datetime:
"""Generate a monotonically increasing timestamp with second precision."""
with self._timestamp_lock:
# Currently, boto3 cuts off any granularity beyond seconds in the CreateEvent request. While we wait for a fix to the
# boto3 client, the best we can do to allow concurrent events is to increment seconds on the client.
# TODO: Once boto3 supports sending milliseconds, we should make this code increment milliseconds instead of seconds.
current = datetime.now(timezone.utc).replace(microsecond=0)
print("LOOK AT CURRENT:", current)

if self._last_timestamp is None or current > self._last_timestamp:
self._last_timestamp = current
self._sequence_counter = 0
else:
# Same or earlier time - increment sequence and add seconds
self._sequence_counter += 1
self._last_timestamp = self._last_timestamp + timedelta(seconds=self._sequence_counter)

print("LOOK AT lastTimestamp:", self._last_timestamp)

return self._last_timestamp.replace(microsecond=0)

def _create_event_with_monotonic_timestamp(self, **kwargs):
"""Create event with guaranteed monotonic timestamp."""
kwargs['eventTimestamp'] = self._get_monotonic_timestamp()
return self.memory_client.gmdp_client.create_event(**kwargs)

def _get_full_session_id(self, session_id: str) -> str:
"""Get the full session ID with the configured prefix.

Expand Down Expand Up @@ -142,14 +172,13 @@ def create_session(self, session: Session, **kwargs: Any) -> Session:
if session.session_id != self.config.session_id:
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session.session_id}")

event = self.memory_client.gmdp_client.create_event(
event = self._create_event_with_monotonic_timestamp(
memoryId=self.config.memory_id,
actorId=self._get_full_session_id(session.session_id),
sessionId=self.session_id,
payload=[
{"blob": json.dumps(session.to_dict())},
],
eventTimestamp=datetime.now(timezone.utc),
)
logger.info("Created session: %s with event: %s", session.session_id, event.get("event", {}).get("eventId"))
return session
Expand Down Expand Up @@ -213,14 +242,13 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A
if session_id != self.config.session_id:
raise SessionException(f"Session ID mismatch: expected {self.config.session_id}, got {session_id}")

event = self.memory_client.gmdp_client.create_event(
event = self._create_event_with_monotonic_timestamp(
memoryId=self.config.memory_id,
actorId=self._get_full_agent_id(session_agent.agent_id),
sessionId=self.session_id,
payload=[
{"blob": json.dumps(session_agent.to_dict())},
],
eventTimestamp=datetime.now(timezone.utc),
)
logger.info(
"Created agent: %s in session: %s with event %s",
Expand Down Expand Up @@ -325,17 +353,16 @@ def create_message(
actor_id=self.config.actor_id,
session_id=session_id,
messages=messages,
event_timestamp=datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")),
event_timestamp=self._get_monotonic_timestamp(),
)
else:
event = self.memory_client.gmdp_client.create_event(
event = self._create_event_with_monotonic_timestamp(
memoryId=self.config.memory_id,
actorId=self.config.actor_id,
sessionId=session_id,
payload=[
{"blob": json.dumps(messages[0])},
],
eventTimestamp=datetime.fromisoformat(session_message.created_at.replace("Z", "+00:00")),
)
logger.debug("Created event: %s for message: %s", event.get("eventId"), session_message.message_id)
return event
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for AgentCoreMemorySessionManager."""

import threading
from datetime import datetime, timezone
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -1047,3 +1049,110 @@ def test_retrieve_customer_context_exception(self, agentcore_config_with_retriev

# Should not raise exception, just log error
manager.retrieve_customer_context(event)


class TestMonotonicTimestamp:
"""Test monotonic timestamp generation."""

def test_monotonic_timestamps_sequential(self, session_manager):
"""Test that sequential calls produce increasing timestamps."""
timestamps = []
for _ in range(10):
timestamps.append(session_manager._get_monotonic_timestamp())

# Verify all timestamps are strictly increasing
for i in range(1, len(timestamps)):
assert timestamps[i] > timestamps[i-1]

def test_monotonic_timestamps_concurrent(self, session_manager):
"""Test that concurrent calls produce unique increasing timestamps."""
timestamps = []
lock = threading.Lock()

def get_timestamp():
ts = session_manager._get_monotonic_timestamp()
with lock:
timestamps.append(ts)

# Create multiple threads
threads = []
for _ in range(20):
thread = threading.Thread(target=get_timestamp)
threads.append(thread)

# Start all threads
for thread in threads:
thread.start()

# Wait for completion
for thread in threads:
thread.join()

# Sort timestamps and verify uniqueness and ordering
timestamps.sort()
assert len(timestamps) == 20
assert len(set(timestamps)) == 20 # All unique

# Verify strictly increasing
for i in range(1, len(timestamps)):
assert timestamps[i] > timestamps[i-1]

def test_monotonic_timestamp_past_minute(self, session_manager):
"""Test that monotonic timestamps can increment past 1 minute."""
import unittest.mock

# Mock datetime.now to return the same time repeatedly
fixed_time = datetime.now(timezone.utc)

with unittest.mock.patch('bedrock_agentcore.memory.integrations.strands.session_manager.datetime') as mock_datetime:
mock_datetime.now.return_value = fixed_time
mock_datetime.timezone = timezone

timestamps = []
for i in range(70):
timestamp = session_manager._get_monotonic_timestamp()
timestamps.append(timestamp)

# Verify timestamps are monotonically increasing
for i in range(1, len(timestamps)):
assert timestamps[i] > timestamps[i-1]

# Verify we can go past 60 seconds
time_diff = timestamps[-1] - timestamps[0]
assert time_diff.total_seconds() >= 60

# Verify we can go past 60 seconds
time_diff = timestamps[-1] - timestamps[0]
assert time_diff.total_seconds() >= 60



def test_create_event_wrapper_uses_monotonic_timestamp(self, session_manager):
"""Test that the create_event wrapper uses monotonic timestamps."""
# Mock the underlying create_event method
session_manager.memory_client.gmdp_client.create_event = Mock(return_value={"eventId": "test-123"})

# Call the wrapper multiple times
session_manager._create_event_with_monotonic_timestamp(
memoryId="test-memory",
actorId="test-actor",
sessionId="test-session",
payload=[{"blob": "test"}]
)

session_manager._create_event_with_monotonic_timestamp(
memoryId="test-memory",
actorId="test-actor",
sessionId="test-session",
payload=[{"blob": "test2"}]
)

# Verify create_event was called twice
assert session_manager.memory_client.gmdp_client.create_event.call_count == 2

# Get the timestamps from both calls
call1_timestamp = session_manager.memory_client.gmdp_client.create_event.call_args_list[0][1]['eventTimestamp']
call2_timestamp = session_manager.memory_client.gmdp_client.create_event.call_args_list[1][1]['eventTimestamp']

# Verify timestamps are monotonically increasing
assert call2_timestamp > call1_timestamp
Loading