# Client code for using Solar Mini Chat SageMaker Endpoint 

This notebook demonstrates how to invoke a SageMaker inference endpoint.

## 1. Set AWS credentials

In [None]:
import os

# Pick ONE way of supplying AWS credentials to this notebook.
#
# Option A - explicit keys and region: uncomment and fill in the values.
# os.environ['AWS_ACCESS_KEY_ID']='xxx'
# os.environ['AWS_SECRET_ACCESS_KEY']='xxx'
# os.environ['AWS_DEFAULT_REGION']='xxx'

# Option B - a named AWS profile: replace the placeholder with your profile name.
os.environ.update(AWS_PROFILE="YOUR_AWS_PROFILE")

## 2. Prepare input

In [None]:
import boto3

# Name of the endpoint to invoke; replace the placeholder with your own.
endpoint_name = "YOUR_ENDPOINT_NAME"

# Low-level Amazon SageMaker Runtime client, shared by the invocation cells below.
sagemaker_runtime = boto3.client(service_name="sagemaker-runtime")

## 3. Invoke endpoint

### 3.1. Stream mode

In [None]:
!pip3 install sseclient-py

In [None]:
# Request payload for the stream-mode example: a system message followed by
# a user message, with the "stream" flag set to True.
request_body = {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": "What is Large Language Model?",
        },
    ],
    "stream": True,
}

In [None]:
def stream_invoke(endpoint_name, request_body, client=None):
    """Invoke an endpoint in response-stream mode and yield the raw payload bytes.

    This is a generator: the request is not sent until the first item is
    requested from the returned iterator.

    Args:
        endpoint_name: Name of the endpoint to invoke.
        request_body: JSON-serializable request payload; it is serialized with
            ``json.dumps`` and sent with content type ``application/json``.
        client: Optional client object exposing
            ``invoke_endpoint_with_response_stream``. When omitted, the
            notebook-level ``sagemaker_runtime`` client is used.

    Yields:
        The ``Bytes`` value of each event's ``PayloadPart``, in arrival order.
    """
    # Imported here so the function does not depend on a later cell having
    # already executed `import json`.
    import json

    if client is None:
        client = sagemaker_runtime

    response = client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(request_body),
        ContentType="application/json",
    )

    for event in response["Body"]:
        yield event["PayloadPart"]["Bytes"]

In [None]:
import json
import sseclient

# stream_invoke returns a generator of raw byte chunks; SSEClient consumes it
# and splits it into server-sent events.
response = stream_invoke(endpoint_name, request_body)

client = sseclient.SSEClient(response)
for event in client.events():
    # The "[DONE]" data value is treated as the end-of-stream marker.
    if event.data == "[DONE]":
        break

    data = json.loads(event.data)
    choices = data.get("choices")
    if choices:
        delta = choices[0].get("delta") or {}
        # "content" may be missing or null in a chunk; print nothing in that
        # case instead of the literal text "None". Flush so each piece is
        # shown as soon as it arrives.
        print(delta.get("content") or "", end="", flush=True)

# Terminate the streamed line so following output starts on a fresh line.
print()

### 3.2. Non-stream mode

In [None]:
import json

# Same two-message conversation as the stream-mode example, but with the
# "stream" flag set to False.
nonstream_body = {
    "messages": [
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": "What is Large Language Model?",
        },
    ],
    "stream": False,
}

response = sagemaker_runtime.invoke_endpoint(
    Body=json.dumps(nonstream_body),
    ContentType="application/json",
    EndpointName=endpoint_name,
)

# Read the whole response body, decode it to text, then parse it as JSON.
raw_payload = response["Body"].read()
result = json.loads(raw_payload.decode())
print(result)