diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index 535213bd81c6..9f9f3e8283d3 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -366,7 +366,37 @@ async def async_streaming( import aioboto3 session = aioboto3.Session() - async with session.client("sagemaker-runtime", region_name="us-west-2") as client: + + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + + if aws_access_key_id != None: + # uses auth params passed to completion + # aws_access_key_id is not None, assume user is trying to auth using litellm.completion + _client = session.client( + service_name="sagemaker-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + else: + # aws_access_key_id is None, assume user is trying to auth using env variables + # boto3 automaticaly reads env variables + + # we need to read region name from env + # I assume majority of users use .env for auth + region_name = ( + get_secret("AWS_REGION_NAME") + or "us-west-2" # default to us-west-2 if user not specified + ) + _client = session.client( + service_name="sagemaker-runtime", + region_name=region_name, + ) + + async with _client as client: try: response = await client.invoke_endpoint_with_response_stream( EndpointName=model, @@ -395,7 +425,37 @@ async def async_completion( import aioboto3 session = aioboto3.Session() - async with session.client("sagemaker-runtime", region_name="us-west-2") as client: + + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + + if aws_access_key_id != None: + # uses auth params passed to completion + # aws_access_key_id is not None, assume user is trying to auth using litellm.completion + _client = session.client( + service_name="sagemaker-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + else: + # aws_access_key_id is None, assume user is trying to auth using env variables + # boto3 automaticaly reads env variables + + # we need to read region name from env + # I assume majority of users use .env for auth + region_name = ( + get_secret("AWS_REGION_NAME") + or "us-west-2" # default to us-west-2 if user not specified + ) + _client = session.client( + service_name="sagemaker-runtime", + region_name=region_name, + ) + + async with _client as client: ## LOGGING request_str = f""" response = client.invoke_endpoint(