In [43]:
"""
This cell defines helper functions for interacting with the SageMaker endpoint.
A separate function handles streaming responses because their output format is different.
Set the "endpoint_name" variable below to the value from the CloudFormation stack output.
"""

import boto3
import json

# Endpoint to call; set this from the CloudFormation stack output.
endpoint_name = 'llmcpp-llama-2-7b-chat-llama-2-7b-chat-arm-Endpoint'

# SageMaker runtime client used by the invoke helpers defined in this cell.
sagemaker_runtime = boto3.client(
    service_name='sagemaker-runtime',
    region_name='us-east-1',
)

def invoke_sagemaker_endpoint(endpoint_name, llama_args):
    """Send one non-streaming inference request and return the decoded reply.

    Args:
        endpoint_name: Name of the SageMaker endpoint to invoke.
        llama_args: JSON-serialisable dict of generation arguments. It is sent
            as the request body exactly as given (no wrapping envelope).

    Returns:
        The endpoint's response body, parsed from JSON.
    """
    # NOTE(review): this function used to build an
    # {'inference', 'configure', 'args'} wrapper dict that was never sent.
    # The dead local has been removed; the request body is unchanged.
    # Confirm the container really expects the bare arguments.
    response = sagemaker_runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(llama_args),
        ContentType='application/json',
    )
    # The body is a stream of bytes; read it fully before JSON-decoding.
    return json.loads(response['Body'].read().decode())

def invoke_sagemaker_streaming_endpoint(endpoint_name, payload):
    """Invoke the endpoint in streaming mode and print generated text as it arrives.

    Args:
        endpoint_name: Name of the SageMaker endpoint to invoke.
        payload: JSON-serialisable dict sent as the request body.

    Returns:
        None. Each received text fragment is written to stdout without a
        trailing newline.
    """
    # Each chunk is expected to carry a 6-byte prefix ahead of the JSON document
    # (presumably the server-sent-events "data: " marker -- TODO confirm
    # against the container's output format).
    prefix_len = 6

    response = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType='application/json',
    )

    for event in response['Body']:
        # Skip events that carry no payload bytes instead of raising KeyError.
        chunk = (event.get('PayloadPart') or {}).get('Bytes')
        if not chunk:
            continue
        try:
            # strict=False tolerates raw control characters inside JSON strings.
            message = json.loads(chunk[prefix_len:], strict=False)
            text = message["choices"][0]["text"]
        except (ValueError, KeyError, IndexError, TypeError):
            # Not a usable completion message (e.g. an empty token or a
            # non-JSON chunk): skip it. Only parse/shape errors are caught,
            # so KeyboardInterrupt and unrelated failures still propagate.
            continue
        # Printing happens outside the try so output errors are not swallowed.
        print(text, end='', flush=True)


In [None]:
"""
Non-streaming inference example.   
"""


# Generation arguments for a single request/response call.
llama_args = {
    "prompt": "What are top 10 destinations to visit in Europe?",
    "max_tokens": 128,
    "temperature": 0.1,
    "repeat_penalty": 1.5,
    "frequency_penalty": 1.1,
    "top_p": 0.5,
}

# Call the endpoint, then show the first choice's text as the cell output.
inference = invoke_sagemaker_endpoint(
    endpoint_name=endpoint_name,
    llama_args=llama_args,
)
first_choice = inference['choices'][0]
first_choice['text']

In [None]:
"""
Streaming inference example.
To enable streaming mode, set "stream": True in the request arguments.
"""

# Same generation arguments as a regular request, plus the "stream" flag.
llama_args = {
    "prompt": "What are top 10 destinations to visit in Europe?",
    "max_tokens": 300,
    "temperature": 0.1,
    "repeat_penalty": 1.5,
    "frequency_penalty": 1.1,
    "top_p": 0.5,
    "stream": True,
}

# The helper prints text as it arrives and returns nothing.
invoke_sagemaker_streaming_endpoint(
    endpoint_name=endpoint_name,
    payload=llama_args,
)