Skip to content

Commit

Permalink
fix(sagemaker.py): fix async sagemaker calls
Browse files Browse the repository at this point in the history
  • Loading branch information
krrishdholakia committed Feb 21, 2024
1 parent 6546b43 commit 49c4aa5
Showing 1 changed file with 62 additions and 2 deletions.
64 changes: 62 additions & 2 deletions litellm/llms/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,37 @@ async def async_streaming(
import aioboto3

session = aioboto3.Session()
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:

# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)

if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
_client = session.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables

# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
_client = session.client(
service_name="sagemaker-runtime",
region_name=region_name,
)

async with _client as client:
try:
response = await client.invoke_endpoint_with_response_stream(
EndpointName=model,
Expand Down Expand Up @@ -395,7 +425,37 @@ async def async_completion(
import aioboto3

session = aioboto3.Session()
async with session.client("sagemaker-runtime", region_name="us-west-2") as client:

# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)

if aws_access_key_id != None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
_client = session.client(
service_name="sagemaker-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables

# we need to read region name from env
# I assume majority of users use .env for auth
region_name = (
get_secret("AWS_REGION_NAME")
or "us-west-2" # default to us-west-2 if user not specified
)
_client = session.client(
service_name="sagemaker-runtime",
region_name=region_name,
)

async with _client as client:
## LOGGING
request_str = f"""
response = client.invoke_endpoint(
Expand Down

0 comments on commit 49c4aa5

Please sign in to comment.