# Deploy Falcon 7B on Amazon SageMaker using Hugging Face Text Generation Inference (TGI) container

## Resources
- [Falcon-7B model card](https://huggingface.co/tiiuae/falcon-7b)
- [TGI documentation](https://huggingface.co/docs/text-generation-inference/en/index)

## Step 1: Setup

In [None]:
# Upgrade the `sagemaker` package via the IPython %pip magic; --quiet suppresses most pip output.
%pip install --upgrade --quiet sagemaker

In [None]:
import sagemaker
import json
# Show which SDK version the notebook is running against.
sdk_version = sagemaker.__version__
print(f"sagemaker version: {sdk_version}")

In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Read the region through the public accessor; `_region_name` is a private
# attribute of the SDK's Session class and should not be relied on.
region = sess.boto_region_name

## Step 2: Endpoint Deployment

In [None]:
from sagemaker.huggingface import get_huggingface_llm_image_uri

# Look up the container image URI for the "huggingface" LLM backend
# at the version pinned below.
latest_version = "1.4.0"
llm_image = get_huggingface_llm_image_uri(
    "huggingface",
    version=latest_version,
)
print(f"llm image uri: {llm_image}")

In [None]:
from sagemaker.huggingface import HuggingFaceModel

# --- SageMaker hosting settings ---
instance_type = "ml.g5.2xlarge"
number_of_gpu = 1
health_check_timeout = 600  # used as the container startup health-check timeout at deploy time

# --- TGI container environment ---
# json.dumps renders the numeric settings in their string form.
config = dict(
    HF_MODEL_ID="tiiuae/falcon-7b",  # model_id from hf.co/models
    SM_NUM_GPUS=json.dumps(number_of_gpu),  # number of GPUs used per replica
    MAX_INPUT_LENGTH=json.dumps(1024),  # max length of input text
    MAX_TOTAL_TOKENS=json.dumps(2048),  # max length of the generation (including input text)
)

# Model object tying together the execution role, the container image and its environment.
llm_model = HuggingFaceModel(image_uri=llm_image, env=config, role=role)

In [None]:
# Create the endpoint from the model object.
# API reference: https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy
deploy_kwargs = {
    "initial_instance_count": 1,
    "instance_type": instance_type,
    # timeout for loading the model
    "container_startup_health_check_timeout": health_check_timeout,
}
llm = llm_model.deploy(**deploy_kwargs)

## Step 3: Run Inference

In [None]:
# Chat-style prompt; it ends with the "Falcon:" turn marker so the model
# continues with the assistant's reply.
prompt = """You are an helpful Assistant, called Falcon. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Falcon:"""

# Generation parameters sent along with the prompt.
# NOTE(review): "\nUser:" is not among the stop sequences, so the model may
# keep generating further dialogue turns — confirm this is intended.
payload = {
    "inputs": prompt,
    "parameters": {
        "do_sample": True,
        "top_p": 0.9,
        "temperature": 0.8,
        "max_new_tokens": 512,
        "repetition_penalty": 1.03,
        "stop": ["<|endoftext|>", "</s>"],
    },
}

# send request to endpoint
response = llm.predict(payload)

# Print the assistant's response. Strip the prompt only when the endpoint
# actually echoed it back; slicing by len(prompt) unconditionally would cut
# off the beginning of the answer whenever only the new text is returned.
generated_text = response[0]["generated_text"]
if generated_text.startswith(prompt):
    assistant = generated_text[len(prompt):]
else:
    assistant = generated_text
print(assistant)

## Step 4: Cleanup

In [None]:
# Cleanup: remove the model and then the endpoint created by the deployment above.
# NOTE(review): the model is deleted before the endpoint; presumably the predictor
# resolves its model through the still-existing endpoint configuration — keep this
# order unless verified otherwise.
llm.delete_model()
llm.delete_endpoint()