# Deploy on AWS SageMaker

## Install the necessary packages

In [None]:
!pip install -U sagemaker==2.192.1

## Deploy the Model as a SageMaker Endpoint

In [None]:
import sagemaker
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
import time

# Execution role that is handed to the model definition further down.
role = sagemaker.get_execution_role()

# Session object; its region is used when looking up the container image.
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name

In [None]:
# Look up the LLM container image for this region.
# backend may be "huggingface" (used here) or "lmi".
image_uri = get_huggingface_llm_image_uri(
    backend="huggingface",
    version="1.1.0",
    region=region,
)

In [None]:
# Show the resolved container image URI as this cell's output.
image_uri

In [None]:
# UTC-timestamped name, reused below for both the model and the endpoint.
model_name = f"MistralLite-{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}"

# Environment variables handed to the container through `env=`.
# Every value, including the numeric limits, is passed as a string.
hub = dict(
    HF_MODEL_ID="amazon/MistralLite",
    HF_TASK="text-generation",
    SM_NUM_GPUS="1",
    MAX_INPUT_LENGTH="12000",
    MAX_TOTAL_TOKENS="12288",
    MAX_BATCH_PREFILL_TOKENS="12288",
    MAX_BATCH_TOTAL_TOKENS="12288",
    DTYPE="bfloat16",
)

model = HuggingFaceModel(name=model_name, env=hub, role=role, image_uri=image_uri)

# Deploy on a single ml.g5.2xlarge instance; `predictor` is used by the inference cells.
predictor = model.deploy(
    endpoint_name=model_name,
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
)

## Perform Inference

### SageMaker SDK

In [None]:
# Request body: the prompt under "inputs", generation settings under "parameters".
# Optional settings left disabled: "typical_p": 0.2, "temperature": None, "truncate": None, "seed": 1
input_data = dict(
    inputs="<|prompter|>What are the main challenges to support a long context for LLM?</s><|assistant|>",
    parameters=dict(
        do_sample=False,
        max_new_tokens=400,
        return_full_text=False,
    ),
)

# The response is indexed as a sequence whose first item carries "generated_text".
result = predictor.predict(input_data)[0]["generated_text"]
print(result)

### boto3

In [None]:
import boto3
import json
def call_endpoint(client, prompt, endpoint_name, paramters):
    """Send a prompt to an endpoint and return the generated text.

    Args:
        client: Object exposing ``invoke_endpoint`` (in this notebook, a boto3
            "sagemaker-runtime" client). If ``None``, a new client is created.
        prompt: Text sent under the payload's "inputs" key.
        endpoint_name: Value passed as ``EndpointName``.
        paramters: Dict sent under the payload's "parameters" key. The
            misspelled name is kept so existing keyword callers keep working.

    Returns:
        The "generated_text" value of the first item in the JSON response.
    """
    # Honour the client supplied by the caller; only build one when none is given.
    if client is None:
        client = boto3.client("sagemaker-runtime")
    # Use the function argument rather than a module-level `parameters` variable.
    payload = {"inputs": prompt, "parameters": paramters}
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    output = json.loads(response["Body"].read().decode())
    return output[0]["generated_text"]

prompt = "<|prompter|>What are the main challenges to support a long context for LLM?</s><|assistant|>"

# Generation settings sent along with the prompt.
# Optional settings left disabled: "typical_p": 0.2, "temperature": None, "truncate": None, "seed": 10
parameters = dict(
    do_sample=False,
    max_new_tokens=400,
    return_full_text=False,
)

# Target the endpoint created by the deployment cell.
endpoint_name = predictor.endpoint_name
client = boto3.client("sagemaker-runtime")

result = call_endpoint(client, prompt, endpoint_name, parameters)
print(result)