6 changes: 4 additions & 2 deletions pyproject.toml
@@ -29,8 +29,10 @@ dependencies = [
"langchain-community==0.3.20",
"json-repair==0.40.0",
"Jinja2==3.1.6",
"dspy==2.6.*",
"asteval==1.0.6"
"dspy==2.6.11",
"asteval==1.0.6",
"glom==24.11.0",
"aioboto3==14.1.0"
]

[project.optional-dependencies]
63 changes: 51 additions & 12 deletions src/fmcore/aws/factory/boto_factory.py
@@ -1,4 +1,7 @@
from datetime import timezone
from typing import Dict

import aioboto3
import boto3
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session
@@ -14,7 +17,7 @@ class BotoFactory:
@classmethod
def __get_refreshable_session(cls, role_arn: str, region: str, session_name: str) -> boto3.Session:
"""
Creates a Boto3 session with refreshable credentials for the assumed IAM role.
Creates a botocore session with refreshable credentials for the assumed IAM role.

Args:
role_arn (str): ARN of the IAM role to assume.
@@ -49,43 +52,79 @@ def refresh() -> dict:
botocore_session._credentials = refreshable_credentials
botocore_session.set_config_variable(AWSConstants.REGION, region)

return boto3.Session(botocore_session=botocore_session)
return botocore_session

@classmethod
def __create_session(cls, *, role_arn: str, region: str, session_name: str) -> boto3.Session:
def __create_session(cls, *, role_arn: str = None, region: str, session_name: str) -> boto3.Session:
"""
Creates a Boto3 session, either using role-based authentication or default credentials.

Args:
region (str): AWS region for the session.
role_arn (str): IAM role ARN to assume (if provided).
role_arn (str, optional): IAM role ARN to assume.
session_name (str): Name for the session.

Returns:
boto3.Session: A configured Boto3 session.
"""
return (
cls.__get_refreshable_session(role_arn=role_arn, region=region, session_name=session_name)
if role_arn
else boto3.Session(region_name=region)
if not role_arn:
return boto3.Session(region_name=region)

# Get a botocore session with refreshable credentials
botocore_session = cls.__get_refreshable_session(
role_arn=role_arn, region=region, session_name=session_name
)

return boto3.Session(botocore_session=botocore_session)

@classmethod
def get_client(cls, *, service_name: str, region: str, role_arn: str) -> boto3.client:
def get_client(cls, *, service_name: str, region: str, role_arn: str = None) -> boto3.client:
"""
Retrieves a cached Boto3 client or creates a new one.

Args:
service_name (str): AWS service name (e.g., 's3', 'bedrock-runtime').
region (str): AWS region for the client.
role_arn (str): IAM role ARN for authentication (optional).
role_arn (str, optional): IAM role ARN for authentication.

Returns:
boto3.client: A configured Boto3 client.
"""
key = f"{service_name}-{region}"
session = cls.__create_session(region=region, role_arn=role_arn, session_name=f"{key}-Session")
key = f"{service_name}-{region}-{role_arn or 'default'}"

if key not in cls.__clients:
session = cls.__create_session(
region=region, role_arn=role_arn, session_name=f"{service_name}-Session"
)
cls.__clients[key] = session.client(service_name, region_name=region)

return cls.__clients[key]

@classmethod
def get_async_session(cls, *, service_name: str, region: str, role_arn: str = None) -> aioboto3.Session:
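"""
Returns an aioboto3 Session for the given service and region.

Credentials are obtained once via STS assume-role and attached to the session
as static values; callers create a fresh client from this session inside an
`async with` block for each request.
"""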
session_name: str = f"Async-{service_name}-Session"

def refresh():
sts_client = boto3.client("sts", region_name=region)
creds = sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)["Credentials"]
return {
"access_key": creds["AccessKeyId"],
"secret_key": creds["SecretAccessKey"],
"token": creds["SessionToken"],
"expiry_time": creds["Expiration"].astimezone(timezone.utc).isoformat(),
}

creds = RefreshableCredentials.create_from_metadata(
metadata=refresh(), refresh_using=refresh, method="sts-assume-role"
)

frozen = creds.get_frozen_credentials()

session = aioboto3.Session(
aws_access_key_id=frozen.access_key,
aws_secret_access_key=frozen.secret_key,
aws_session_token=frozen.token,
region_name=region,
)

return session
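For reference, a minimal usage sketch of the two factory entry points; the region, role ARN, and function name below are placeholders, not values from this change:

from fmcore.aws.factory.boto_factory import BotoFactory

# Cached synchronous client, keyed by service, region, and role.
lambda_client = BotoFactory.get_client(
    service_name="lambda",
    region="us-east-1",  # placeholder region
    role_arn="arn:aws:iam::123456789012:role/example-role",  # placeholder ARN
)

# Async path: the session is reusable, but a fresh client is created per
# `async with` block because aioboto3's client() is an async context manager.
async def invoke_async(payload: bytes):
    session = BotoFactory.get_async_session(
        service_name="lambda",
        region="us-east-1",
        role_arn="arn:aws:iam::123456789012:role/example-role",
    )
    async with session.client("lambda") as client:
        return await client.invoke(FunctionName="example-fn", Payload=payload)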
1 change: 1 addition & 0 deletions src/fmcore/llm/__init__.py
@@ -1,3 +1,4 @@
from fmcore.llm.base_llm import BaseLLM
from fmcore.llm.bedrock_llm import BedrockLLM
from fmcore.llm.lambda_llm import LambdaLLM
from fmcore.llm.distributed_llm import DistributedLLM
5 changes: 4 additions & 1 deletion src/fmcore/llm/bedrock_llm.py
@@ -9,9 +9,10 @@
from fmcore.llm.base_llm import BaseLLM
from fmcore.llm.types.llm_types import LLMConfig
from fmcore.utils.rate_limit_utils import RateLimiterUtils
from fmcore.utils.retry_utils import RetryUtil


class BedrockLLM(BaseLLM, BaseModel):
class BedrockLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], BaseModel):
"""
AWS Bedrock language model with built-in asynchronous rate limiting.

@@ -58,6 +59,7 @@ def invoke(self, messages: List[BaseMessage]) -> BaseMessage:
"""
return self.client.invoke(input=messages)

@RetryUtil.with_backoff(lambda self: self.config.provider_params.retries)
async def ainvoke(self, messages: List[BaseMessage]) -> BaseMessage:
"""
Asynchronously invokes the model with rate limiting.
@@ -83,6 +85,7 @@ def stream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]:
"""
return self.client.stream(input=messages)

@RetryUtil.with_backoff(lambda self: self.config.provider_params.retries)
async def astream(self, messages: List[BaseMessage]) -> AsyncIterator[BaseMessageChunk]:
"""
Asynchronously streams response chunks from the model with rate limiting.
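The new `@RetryUtil.with_backoff(...)` decorator is applied with a callable that pulls the retry settings off the instance at call time. The actual fmcore implementation is not part of this diff; the following is only a rough sketch of that pattern, with an assumed `max_attempts` field and exponential backoff:

import asyncio
import functools
import random


def with_backoff(get_retry_config):
    """Retry an async method with exponential backoff.

    `get_retry_config` receives the instance (`self`), so retry settings
    can be read from the object's own config when the method is called.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            config = get_retry_config(self)
            max_attempts = getattr(config, "max_attempts", 3)  # assumed field name
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(self, *args, **kwargs)
                except Exception:
                    if attempt == max_attempts:
                        raise
                    # Exponential backoff with jitter before retrying.
                    await asyncio.sleep(2 ** attempt + random.random())
        return wrapper
    return decorator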
199 changes: 199 additions & 0 deletions src/fmcore/llm/lambda_llm.py
@@ -0,0 +1,199 @@
import json
from typing import List, Iterator, AsyncIterator, Dict

import aioboto3
from aiolimiter import AsyncLimiter
from botocore.client import BaseClient
from langchain_aws import ChatBedrockConverse
from langchain_community.adapters.openai import convert_dict_to_message
from pydantic import BaseModel
from langchain_core.messages import (
BaseMessage,
BaseMessageChunk,
convert_to_openai_messages,
)

from fmcore.aws.factory.boto_factory import BotoFactory
from fmcore.llm.base_llm import BaseLLM
from fmcore.llm.types.llm_types import LLMConfig
from fmcore.llm.types.provider_types import LambdaProviderParams
from fmcore.utils.rate_limit_utils import RateLimiterUtils
from fmcore.utils.retry_utils import RetryUtil


class LambdaLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], BaseModel):
"""
An LLM implementation that routes requests through an AWS Lambda function.

This class uses a synchronous boto3 Lambda client and an aioboto3 session to
interact with an LLM hosted behind an AWS Lambda function. It includes automatic
async rate limiting and supports OpenAI-style message formatting.

Attributes:
sync_client (BaseClient): Boto3 synchronous client for AWS Lambda.
async_session (aioboto3.Session): aioboto3 session used to create asynchronous Lambda clients.
rate_limiter (AsyncLimiter): Async limiter to enforce API rate limits.

Note:
The `async_client` is not stored directly because `aioboto3.Session.client(...)` returns
an asynchronous context manager, which must be used with `async with` and cannot
be reused safely across calls. Instead, we store an `aioboto3.Session` instance
in `async_session`, from which a fresh client is created inside each `async with`
block.

"""

aliases = ["LAMBDA"]

sync_client: BaseClient
async_session: aioboto3.Session  # A session is stored because aioboto3's client() returns an async context manager
rate_limiter: AsyncLimiter

@classmethod
def _get_instance(cls, *, llm_config: LLMConfig) -> "LambdaLLM":
"""
Factory method to create an instance of LambdaLLM with the given configuration.

Args:
llm_config (LLMConfig): The LLM configuration, including model and provider details.

Returns:
LambdaLLM: A configured instance of the Lambda-backed LLM.
"""
provider_params: LambdaProviderParams = llm_config.provider_params

sync_client = BotoFactory.get_client(
service_name="lambda",
region=provider_params.region,
role_arn=provider_params.role_arn,
)
async_session = BotoFactory.get_async_session(
service_name="lambda",
region=provider_params.region,
role_arn=provider_params.role_arn,
)

rate_limiter = RateLimiterUtils.create_async_rate_limiter(
rate_limit_config=provider_params.rate_limit
)

return LambdaLLM(
config=llm_config, sync_client=sync_client, async_session=async_session, rate_limiter=rate_limiter
)

def convert_messages_to_lambda_payload(self, messages: List[BaseMessage]) -> Dict:
"""
Converts internal message objects to the payload format expected by the Lambda function.
All Lambda functions are expected to accept messages in the OpenAI chat format.

Args:
messages (List[BaseMessage]): List of internal message objects.

Returns:
Dict: The payload dictionary to send to the Lambda function.
"""
return {
"modelId": self.config.model_id,
"messages": convert_to_openai_messages(messages),
"model_params": self.config.model_params.model_dump(),
}

def convert_lambda_response_to_messages(self, response: Dict) -> BaseMessage:
"""
Converts the raw Lambda function response into a BaseMessage.

This method expects the Lambda response to contain a 'Payload' key with a stream
of OpenAI-style messages (a list of dictionaries). It parses the stream, extracts
the first message, and converts it into a BaseMessage instance.

Args:
response (Dict): The response dictionary returned from the Lambda invocation.

Returns:
BaseMessage: The first parsed message from the response.
"""
response_payload: List[Dict] = json.load(response["Payload"])
# The Lambda returns a list of messages in OpenAI format.
# Currently, we only expect a single response message,
# so we take the first item in the list.
return convert_dict_to_message(response_payload[0])

def invoke(self, messages: List[BaseMessage]) -> BaseMessage:
"""
Synchronously invokes the Lambda function with given messages.

Args:
messages (List[BaseMessage]): Input messages for the model.

Returns:
BaseMessage: Response message from the model.
"""
payload = self.convert_messages_to_lambda_payload(messages)
response = self.sync_client.invoke(
FunctionName=self.config.provider_params.function_arn,
InvocationType="RequestResponse",
Payload=json.dumps(payload),
)
return self.convert_lambda_response_to_messages(response)

@RetryUtil.with_backoff(lambda self: self.config.provider_params.retries)
async def ainvoke(self, messages: List[BaseMessage]) -> BaseMessage:
"""
Asynchronously invokes the Lambda function with rate limiting.

Args:
messages (List[BaseMessage]): Input messages for the model.

Returns:
BaseMessage: Response message from the model.
"""
async with self.rate_limiter:
async with self.async_session.client("lambda") as lambda_client:
payload = self.convert_messages_to_lambda_payload(messages)
response = await lambda_client.invoke(
FunctionName=self.config.provider_params.function_arn,
InvocationType="RequestResponse",
Payload=json.dumps(payload),
)
payload = await response["Payload"].read()
response_payload: List[Dict] = json.loads(payload.decode("utf-8"))
# The Lambda returns a list of messages in OpenAI format.
# Currently, we only expect a single response message,
# so we take the first item in the list.
return convert_dict_to_message(response_payload[0])

def stream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]:
"""
Not implemented. Streaming is not supported for LambdaLLM.

Raises:
NotImplementedError
"""
raise NotImplementedError("Streaming is not implemented for LambdaLLM")

async def astream(self, messages: List[BaseMessage]) -> AsyncIterator[BaseMessageChunk]:
"""
Not implemented. Asynchronous streaming is not supported for LambdaLLM.

Raises:
NotImplementedError
"""
raise NotImplementedError("Streaming is not implemented for LambdaLLM")

def batch(self, messages: List[List[BaseMessage]]) -> List[BaseMessage]:
"""
Not implemented. Batch processing is not supported for LambdaLLM.

Raises:
NotImplementedError
"""
raise NotImplementedError("Batch processing is not implemented for LambdaLLM.")

async def abatch(self, messages: List[List[BaseMessage]]) -> List[BaseMessage]:
"""
Not implemented. Asynchronous batch processing is not supported for LambdaLLM.

Raises:
NotImplementedError
"""
raise NotImplementedError("Batch processing is not implemented for LambdaLLM.")
20 changes: 20 additions & 0 deletions src/fmcore/llm/types/provider_types.py
@@ -52,3 +52,23 @@ class BedrockProviderParams(BaseProviderParams, AWSAccountMixin, RateLimiterMixi
"""

aliases = ["BEDROCK"]


class LambdaProviderParams(BaseProviderParams, AWSAccountMixin, RateLimiterMixin, RetryConfigMixin):
"""
Configuration for a Lambda provider using AWS.

This class combines AWS account settings with request configuration parameters
(such as rate limits and retry policies) needed to invoke an LLM hosted behind an AWS Lambda function.
It mixes in AWS-specific account details, rate limiting, and retry configurations
to form a complete provider setup.

Mixes in:
AWSAccountMixin: Supplies AWS-specific account details (e.g., role ARN, region).
RateLimiterMixin: Supplies API rate limiting settings.
RetryConfigMixin: Supplies retry policy settings.
"""

aliases = ["LAMBDA"]

function_arn: str
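A provider configuration for the new class might look roughly like the dictionary below. Only `function_arn` is introduced in this diff; the remaining keys are assumed to come from the AWS account, rate-limit, and retry mixins, and their exact names and shapes are guesses here:

llm_config_dict = {
    "provider_type": "LAMBDA",  # assumed discriminator key
    "model_id": "example-model",
    "provider_params": {
        "function_arn": "arn:aws:lambda:us-east-1:123456789012:function:example-llm",
        "role_arn": "arn:aws:iam::123456789012:role/example-role",  # from AWSAccountMixin
        "region": "us-east-1",                                       # from AWSAccountMixin
        "rate_limit": {"max_rate": 10, "time_period": 1},            # from RateLimiterMixin (assumed shape)
        "retries": {"max_attempts": 3},                              # from RetryConfigMixin (assumed shape)
    },
}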