diff --git a/pyproject.toml b/pyproject.toml index 2125e76..8919d40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,10 @@ dependencies = [ "langchain-community==0.3.20", "json-repair==0.40.0", "Jinja2==3.1.6", - "dspy==2.6.*", - "asteval==1.0.6" + "dspy==2.6.11", + "asteval==1.0.6", + "glom==24.11.0", + "aioboto3==14.1.0" ] [project.optional-dependencies] diff --git a/src/fmcore/aws/factory/boto_factory.py b/src/fmcore/aws/factory/boto_factory.py index 032c18b..ea471f3 100644 --- a/src/fmcore/aws/factory/boto_factory.py +++ b/src/fmcore/aws/factory/boto_factory.py @@ -1,4 +1,7 @@ +from datetime import timezone from typing import Dict + +import aioboto3 import boto3 from botocore.credentials import RefreshableCredentials from botocore.session import get_session @@ -14,7 +17,7 @@ class BotoFactory: @classmethod def __get_refreshable_session(cls, role_arn: str, region: str, session_name: str) -> boto3.Session: """ - Creates a Boto3 session with refreshable credentials for the assumed IAM role. + Creates a botocore session with refreshable credentials for the assumed IAM role. Args: role_arn (str): ARN of the IAM role to assume. @@ -49,43 +52,79 @@ def refresh() -> dict: botocore_session._credentials = refreshable_credentials botocore_session.set_config_variable(AWSConstants.REGION, region) - return boto3.Session(botocore_session=botocore_session) + return botocore_session @classmethod - def __create_session(cls, *, role_arn: str, region: str, session_name: str) -> boto3.Session: + def __create_session(cls, *, role_arn: str = None, region: str, session_name: str) -> boto3.Session: """ Creates a Boto3 session, either using role-based authentication or default credentials. Args: region (str): AWS region for the session. - role_arn (str): IAM role ARN to assume (if provided). + role_arn (str, optional): IAM role ARN to assume. + session_name (str): Name for the session. Returns: boto3.Session: A configured Boto3 session. """ - return ( - cls.__get_refreshable_session(role_arn=role_arn, region=region, session_name=session_name) - if role_arn - else boto3.Session(region_name=region) + if not role_arn: + return boto3.Session(region_name=region) + + # Get a botocore session with refreshable credentials + botocore_session = cls.__get_refreshable_session( + role_arn=role_arn, region=region, session_name=session_name ) + return boto3.Session(botocore_session=botocore_session) + @classmethod - def get_client(cls, *, service_name: str, region: str, role_arn: str) -> boto3.client: + def get_client(cls, *, service_name: str, region: str, role_arn: str = None) -> boto3.client: """ Retrieves a cached Boto3 client or creates a new one. Args: service_name (str): AWS service name (e.g., 's3', 'bedrock-runtime'). region (str): AWS region for the client. - role_arn (str): IAM role ARN for authentication (optional). + role_arn (str, optional): IAM role ARN for authentication. Returns: boto3.client: A configured Boto3 client. 
""" - key = f"{service_name}-{region}" - session = cls.__create_session(region=region, role_arn=role_arn, session_name=f"{key}-Session") + key = f"{service_name}-{region}-{role_arn or 'default'}" if key not in cls.__clients: + session = cls.__create_session( + region=region, role_arn=role_arn, session_name=f"{service_name}-Session" + ) cls.__clients[key] = session.client(service_name, region_name=region) return cls.__clients[key] + + @classmethod + def get_async_session(cls, *, service_name: str, region: str, role_arn: str = None) -> aioboto3.Session: + session_name: str = f"Async-{service_name}-Session" + + def refresh(): + sts_client = boto3.client("sts", region_name=region) + creds = sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)["Credentials"] + return { + "access_key": creds["AccessKeyId"], + "secret_key": creds["SecretAccessKey"], + "token": creds["SessionToken"], + "expiry_time": creds["Expiration"].astimezone(timezone.utc).isoformat(), + } + + creds = RefreshableCredentials.create_from_metadata( + metadata=refresh(), refresh_using=refresh, method="sts-assume-role" + ) + + frozen = creds.get_frozen_credentials() + + session = aioboto3.Session( + aws_access_key_id=frozen.access_key, + aws_secret_access_key=frozen.secret_key, + aws_session_token=frozen.token, + region_name=region, + ) + + return session diff --git a/src/fmcore/llm/__init__.py b/src/fmcore/llm/__init__.py index 9a31061..ad7abc2 100644 --- a/src/fmcore/llm/__init__.py +++ b/src/fmcore/llm/__init__.py @@ -1,3 +1,4 @@ from fmcore.llm.base_llm import BaseLLM from fmcore.llm.bedrock_llm import BedrockLLM +from fmcore.llm.lambda_llm import LambdaLLM from fmcore.llm.distributed_llm import DistributedLLM \ No newline at end of file diff --git a/src/fmcore/llm/bedrock_llm.py b/src/fmcore/llm/bedrock_llm.py index d31aa94..5ca643b 100644 --- a/src/fmcore/llm/bedrock_llm.py +++ b/src/fmcore/llm/bedrock_llm.py @@ -9,9 +9,10 @@ from fmcore.llm.base_llm import BaseLLM from fmcore.llm.types.llm_types import LLMConfig from fmcore.utils.rate_limit_utils import RateLimiterUtils +from fmcore.utils.retry_utils import RetryUtil -class BedrockLLM(BaseLLM, BaseModel): +class BedrockLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], BaseModel): """ AWS Bedrock language model with built-in asynchronous rate limiting. @@ -58,6 +59,7 @@ def invoke(self, messages: List[BaseMessage]) -> BaseMessage: """ return self.client.invoke(input=messages) + @RetryUtil.with_backoff(lambda self: self.config.provider_params.retries) async def ainvoke(self, messages: List[BaseMessage]) -> BaseMessage: """ Asynchronously invokes the model with rate limiting. @@ -83,6 +85,7 @@ def stream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]: """ return self.client.stream(input=messages) + @RetryUtil.with_backoff(lambda self: self.config.provider_params.retries) async def astream(self, messages: List[BaseMessage]) -> AsyncIterator[BaseMessageChunk]: """ Asynchronously streams response chunks from the model with rate limiting. 
diff --git a/src/fmcore/llm/lambda_llm.py b/src/fmcore/llm/lambda_llm.py
new file mode 100644
index 0000000..ff1be68
--- /dev/null
+++ b/src/fmcore/llm/lambda_llm.py
@@ -0,0 +1,199 @@
+import json
+from typing import List, Iterator, AsyncIterator, Dict
+
+import aioboto3
+from aiolimiter import AsyncLimiter
+from botocore.client import BaseClient
+from langchain_aws import ChatBedrockConverse
+from langchain_community.adapters.openai import convert_dict_to_message
+from pydantic import BaseModel
+from langchain_core.messages import (
+    BaseMessage,
+    BaseMessageChunk,
+    convert_to_openai_messages,
+)
+
+from fmcore.aws.factory.boto_factory import BotoFactory
+from fmcore.llm.base_llm import BaseLLM
+from fmcore.llm.types.llm_types import LLMConfig
+from fmcore.llm.types.provider_types import LambdaProviderParams
+from fmcore.utils.rate_limit_utils import RateLimiterUtils
+from fmcore.utils.retry_utils import RetryUtil
+
+
+class LambdaLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], BaseModel):
+    """
+    An LLM implementation that routes requests through an AWS Lambda function.
+
+    This class uses both synchronous and asynchronous boto3 Lambda clients to
+    interact with an LLM hosted via AWS Lambda. It includes automatic async rate
+    limiting and supports OpenAI-style message formatting.
+
+    Attributes:
+        sync_client (BaseClient): Boto3 synchronous client for AWS Lambda.
+        async_session (aioboto3.Session): aioboto3 asynchronous session for AWS Lambda.
+        rate_limiter (AsyncLimiter): Async limiter to enforce API rate limits.
+
+    Note:
+        The `async_client` is not stored directly because `aioboto3.Session.client(...)` returns
+        an asynchronous context manager, which must be used with `async with` and cannot
+        be reused safely across calls. Instead, we store an `aioboto3.Session` instance
+        in `async_session`, from which a fresh client is created inside each `async with`
+        block.
+
+    """
+
+    aliases = ["LAMBDA"]
+
+    sync_client: BaseClient
+    async_session: aioboto3.Session  # Using a session because aioboto3 clients are async context managers
+    rate_limiter: AsyncLimiter
+
+    @classmethod
+    def _get_instance(cls, *, llm_config: LLMConfig) -> "LambdaLLM":
+        """
+        Factory method to create an instance of LambdaLLM with the given configuration.
+
+        Args:
+            llm_config (LLMConfig): The LLM configuration, including model and provider details.
+
+        Returns:
+            LambdaLLM: A configured instance of the Lambda-backed LLM.
+        """
+        provider_params: LambdaProviderParams = llm_config.provider_params
+
+        sync_client = BotoFactory.get_client(
+            service_name="lambda",
+            region=provider_params.region,
+            role_arn=provider_params.role_arn,
+        )
+        async_session = BotoFactory.get_async_session(
+            service_name="lambda",
+            region=provider_params.region,
+            role_arn=provider_params.role_arn,
+        )
+
+        rate_limiter = RateLimiterUtils.create_async_rate_limiter(
+            rate_limit_config=provider_params.rate_limit
+        )
+
+        return LambdaLLM(
+            config=llm_config, sync_client=sync_client, async_session=async_session, rate_limiter=rate_limiter
+        )
+
+    def convert_messages_to_lambda_payload(self, messages: List[BaseMessage]) -> Dict:
+        """
+        Converts internal message objects to the payload format expected by the Lambda function.
+        All Lambda functions are expected to accept messages in the OpenAI format.
+
+        Args:
+            messages (List[BaseMessage]): List of internal message objects.
+
+        Returns:
+            Dict: The payload dictionary to send to the Lambda function.
+ """ + return { + "modelId": self.config.model_id, + "messages": convert_to_openai_messages(messages), + "model_params": self.config.model_params.model_dump(), + } + + def convert_lambda_response_to_messages(self, response: Dict) -> BaseMessage: + """ + Converts the raw Lambda function response into a BaseMessage. + + This method expects the Lambda response to contain a 'Payload' key with a stream + of OpenAI-style messages (a list of dictionaries). It parses the stream, extracts + the first message, and converts it into a BaseMessage instance. + + Args: + response (Dict): The response dictionary returned from the Lambda invocation. + + Returns: + BaseMessage: The first parsed message from the response. + """ + response_payload: List[Dict] = json.load(response["Payload"]) + # The Lambda returns a list of messages in OpenAI format. + # Currently, we only expect a single response message, + # so we take the first item in the list. + return convert_dict_to_message(response_payload[0]) + + def invoke(self, messages: List[BaseMessage]) -> BaseMessage: + """ + Synchronously invokes the Lambda function with given messages. + + Args: + messages (List[BaseMessage]): Input messages for the model. + + Returns: + BaseMessage: Response message from the model. + """ + payload = self.convert_messages_to_lambda_payload(messages) + response = self.sync_client.invoke( + FunctionName=self.config.provider_params.function_arn, + InvocationType="RequestResponse", + Payload=json.dumps(payload), + ) + return self.convert_lambda_response_to_messages(response) + + @RetryUtil.with_backoff(lambda self: self.config.provider_params.retries) + async def ainvoke(self, messages: List[BaseMessage]) -> BaseMessage: + """ + Asynchronously invokes the Lambda function with rate limiting. + + Args: + messages (List[BaseMessage]): Input messages for the model. + + Returns: + BaseMessage: Response message from the model. + """ + async with self.rate_limiter: + async with self.async_session.client("lambda") as lambda_client: + payload = self.convert_messages_to_lambda_payload(messages) + response = await lambda_client.invoke( + FunctionName=self.config.provider_params.function_arn, + InvocationType="RequestResponse", + Payload=json.dumps(payload), + ) + payload = await response["Payload"].read() + response_payload: List[Dict] = json.loads(payload.decode("utf-8")) + # The Lambda returns a list of messages in OpenAI format. + # Currently, we only expect a single response message, + # so we take the first item in the list. + return convert_dict_to_message(response_payload[0]) + + def stream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]: + """ + Not implemented. Streaming is not supported for LambdaLLM. + + Raises: + NotImplementedError + """ + raise NotImplementedError("Streaming is not implemented for LambdaLLM") + + async def astream(self, messages: List[BaseMessage]) -> AsyncIterator[BaseMessageChunk]: + """ + Not implemented. Asynchronous streaming is not supported for LambdaLLM. + + Raises: + NotImplementedError + """ + raise NotImplementedError("Streaming is not implemented for LambdaLLM") + + def batch(self, messages: List[List[BaseMessage]]) -> List[BaseMessage]: + """ + Not implemented. Batch processing is not supported for LambdaLLM. + + Raises: + NotImplementedError + """ + raise NotImplementedError("Batch processing is not implemented for LambdaLLM.") + + async def abatch(self, messages: List[List[BaseMessage]]) -> List[BaseMessage]: + """ + Not implemented. 
Asynchronous batch processing is not supported for LambdaLLM.
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError("Batch processing is not implemented for LambdaLLM.")
diff --git a/src/fmcore/llm/types/provider_types.py b/src/fmcore/llm/types/provider_types.py
index 935b626..e9751c8 100644
--- a/src/fmcore/llm/types/provider_types.py
+++ b/src/fmcore/llm/types/provider_types.py
@@ -52,3 +52,23 @@ class BedrockProviderParams(BaseProviderParams, AWSAccountMixin, RateLimiterMixi
     """
 
     aliases = ["BEDROCK"]
+
+
+class LambdaProviderParams(BaseProviderParams, AWSAccountMixin, RateLimiterMixin, RetryConfigMixin):
+    """
+    Configuration for a Lambda provider on AWS.
+
+    This class combines AWS account settings with request configuration parameters
+    (such as rate limits and retry policies) needed to invoke a Lambda-hosted LLM.
+    It mixes in AWS-specific account details, rate limiting, and retry configurations
+    to form a complete provider setup.
+
+    Mixes in:
+        AWSAccountMixin: Supplies AWS-specific account details (e.g., role ARN, region).
+        RateLimiterMixin: Supplies API rate limiting settings.
+        RetryConfigMixin: Supplies retry policy settings.
+    """
+
+    aliases = ["LAMBDA"]
+
+    function_arn: str
diff --git a/src/fmcore/types/config_types.py b/src/fmcore/types/config_types.py
index 50c3901..b03074e 100644
--- a/src/fmcore/types/config_types.py
+++ b/src/fmcore/types/config_types.py
@@ -1,4 +1,4 @@
-from typing import Dict
+from typing import Dict, List
 
 from bears import FileMetadata
 from pydantic import Field
@@ -30,6 +30,18 @@ class RetryConfig(MutableTyped):
     max_retries: int = Field(default=3)
     backoff_factor: float = Field(default=1.0)
     jitter: float = Field(default=1.0)
+    retryable_exceptions: List[str] = Field(
+        default_factory=lambda: [
+            "InvalidSignatureException",
+            "ThrottlingException",
+            "ModelTimeoutException",
+            "ServiceUnavailableException",
+            "ModelNotReadyException",
+            "ServiceQuotaExceededException",
+            "ModelErrorException",
+            "EndpointConnectionError",
+        ]
+    )
 
 
 class DatasetConfig(MutableTyped):
diff --git a/src/fmcore/utils/retry_utils.py b/src/fmcore/utils/retry_utils.py
new file mode 100644
index 0000000..052d45c
--- /dev/null
+++ b/src/fmcore/utils/retry_utils.py
@@ -0,0 +1,48 @@
+import backoff
+
+from functools import wraps
+from fmcore.types.config_types import RetryConfig
+from fmcore.utils.logging_utils import Log
+
+
+class RetryUtil:
+    """
+    Utility class for applying backoff-based retry logic using a custom RetryConfig.
+
+    Methods:
+        with_backoff(retry_config_getter): Decorator factory that applies retry logic
+            using the RetryConfig from the instance.
+    """
+
+    @staticmethod
+    def with_backoff(retry_config_getter):
+        """
+        Decorator factory that applies retry logic using the RetryConfig from the class instance.
+
+        Args:
+            retry_config_getter (Callable): A function that takes `self` and returns a RetryConfig.
+
+        Returns:
+            Callable: Decorator for the method to be retried.
+ """ + + def decorator(func): + @wraps(func) + async def wrapper(self, *args, **kwargs): + retry_config: RetryConfig = retry_config_getter(self) + + decorated = backoff.on_exception( + exception=Exception, + wait_gen=backoff.expo, + giveup=lambda e: not any( + exception in str(e) for exception in retry_config.retryable_exceptions + ), + max_tries=retry_config.max_retries, + logger=Log.get_logger(), + )(func) + + return await decorated(self, *args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/local/inference.py b/tests/local/inference.py index 8b20a12..5f1dcf9 100644 --- a/tests/local/inference.py +++ b/tests/local/inference.py @@ -34,6 +34,33 @@ "inference_manager_params": {"num_process": 10}, } +lambda_inference_manager_config = { + "inference_manager_type": "MULTI_PROCESS", + "llm_config": { + "provider_type": "LAMBDA", + "model_id": "mistralai/Mistral-Nemo-Instruct-2407", + "model_params": { + "temperature": 0.5, + "max_tokens": 1024 + }, + "provider_params": { + "role_arn": "arn:aws:iam:::role/", + "function_arn": "arn:aws:lambda:::function:", + "region": "us-west-2", + "rate_limit": { + "max_rate": 10000, + "time_period": 60 + }, + "retries": { + "max_retries": 3 + } + } + }, + "inference_manager_params": { + "num_process": 10 + } +} + # ------------------------------- # Question Generator # ------------------------------- @@ -87,4 +114,5 @@ def run_inference(config_dict, num_questions=100): if __name__ == "__main__": # Choose one: - run_inference(bedrock_config_dict) + #run_inference(bedrock_config_dict) + run_inference(lambda_inference_manager_config)