# Conversational Search

In this lab, we leverage the LangChain framework to implement a Retrieval Augmented Generation (RAG) solution. For more information about LangChain, please refer to: https://python.langchain.com/docs/use_cases/question_answering/

## Step 1: Initialize

Install library

In [None]:
# Install/upgrade the packages this lab depends on: the SageMaker SDK, the
# OpenSearch Python client, helper packages used for document loading, and LangChain.
!pip install --upgrade sagemaker 
!pip install opensearch-py
!pip install wikipedia unstructured transformers
!pip install langchain

Initialize SageMaker, Boto3

In [None]:
import sagemaker
import boto3
import json
from sagemaker.session import Session

# SageMaker session for this notebook; its caller identity is the IAM role
# that the model deployments further down run under.
sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()

# Region of the default boto3 session, used when creating service clients.
aws_region = boto3.Session().region_name

# An additional SageMaker session handle.
sess = sagemaker.Session()

Get Cloud Formation stack output variables

We also need to grab some key values from the infrastructure we provisioned using CloudFormation. To do this, we will list the outputs from the stack and store this in "outputs" to be used later.

You can ignore any "PythonDeprecationWarning" warnings.

In [None]:
# Short alias for the notebook's AWS region.
region = aws_region

# CloudFormation client used to read the lab stack's outputs.
cfn = boto3.client("cloudformation")

def get_cfn_outputs(stackname):
    """Return the outputs of CloudFormation stack *stackname* as a dict.

    Keys are the stack's ``OutputKey`` values; values are the matching
    ``OutputValue`` strings. Uses the module-level ``cfn`` client.
    """
    stack = cfn.describe_stacks(StackName=stackname)['Stacks'][0]
    return {entry['OutputKey']: entry['OutputValue'] for entry in stack['Outputs']}

## Setup variables to use for the rest of the demo
cloudformation_stack_name = "semantic-search"

# All outputs of the lab stack, keyed by output name.
outputs = get_cfn_outputs(cloudformation_stack_name)

# Host name of the OpenSearch domain provisioned by the stack.
aos_host = outputs["OpenSearchDomainEndpoint"]

outputs

## Step 2 : Verify deployed endpoint for embedding and content generation model

### Get endpoint for embedding

In [None]:
# Name of the embedding endpoint, taken from the CloudFormation stack outputs.
embedding_endpoint_name = outputs["EmbeddingEndpointName"]
print(embedding_endpoint_name)

Verify embedding endpoint is ready

In [None]:
import time

sm_client = boto3.client("sagemaker", aws_region)

# Poll every 15 seconds while the endpoint is still being created.
describe_embedding_endpoint_response = sm_client.describe_endpoint(EndpointName=embedding_endpoint_name)

while describe_embedding_endpoint_response["EndpointStatus"] == 'Creating':
    time.sleep(15)
    print('.', end='')
    describe_embedding_endpoint_response = sm_client.describe_endpoint(EndpointName=embedding_endpoint_name)

# Leaving "Creating" does not mean success: report a failed deployment
# instead of claiming the endpoint was created.
if describe_embedding_endpoint_response["EndpointStatus"] == 'Failed':
    raise RuntimeError(
        f"Embedding endpoint {embedding_endpoint_name} failed: "
        f"{describe_embedding_endpoint_response.get('FailureReason', 'unknown reason')}"
    )
print('embedding endpoint created')

### Get endpoint for content generation

In [None]:
# The LLM endpoint needs its own stack output. Reading 'EmbeddingEndpointName'
# here would send text-generation requests to the embedding model.
# NOTE(review): 'LLMEndpointName' is the assumed output key for the
# content-generation endpoint; confirm it against the stack's Outputs.
llm_output_key = 'LLMEndpointName'
if llm_output_key not in outputs:
    raise KeyError(
        f"Stack output {llm_output_key!r} not found; available outputs: {sorted(outputs)}"
    )
llm_endpoint_name = outputs[llm_output_key]
print(llm_endpoint_name)

Verify LLM endpoint is ready

In [None]:
import time

sm_client = boto3.client("sagemaker", aws_region)

# Poll every 15 seconds while the endpoint is still being created.
describe_llm_endpoint_response = sm_client.describe_endpoint(EndpointName=llm_endpoint_name)

while describe_llm_endpoint_response["EndpointStatus"] == 'Creating':
    time.sleep(15)
    print('.', end='')
    describe_llm_endpoint_response = sm_client.describe_endpoint(EndpointName=llm_endpoint_name)

# Leaving "Creating" does not mean success: report a failed deployment
# instead of claiming the endpoint was created.
if describe_llm_endpoint_response["EndpointStatus"] == 'Failed':
    raise RuntimeError(
        f"LLM endpoint {llm_endpoint_name} failed: "
        f"{describe_llm_endpoint_response.get('FailureReason', 'unknown reason')}"
    )
print('LLM endpoint created')

## Test embedding endpoint

In [None]:
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.embeddings import SagemakerEndpointEmbeddings


class TestContentHandler(EmbeddingsContentHandler):
    """Request/response (de)serializer for the SageMaker embedding endpoint.

    Request body: ``{"text_inputs": <texts>, **model_kwargs}`` as UTF-8 JSON.
    Response body: JSON whose ``"embedding"`` field holds the vectors.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: "list[str]", model_kwargs=None) -> bytes:
        """Encode the texts in *prompt* plus any model kwargs as a JSON request body."""
        # None default avoids a shared mutable default dict.
        extra = model_kwargs if model_kwargs is not None else {}
        input_str = json.dumps({"text_inputs": prompt, **extra})
        return input_str.encode("utf-8")

    def transform_output(self, output) -> "list[list[float]]":
        """Decode the endpoint response stream and return its ``"embedding"`` list."""
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]


test_content_handler = TestContentHandler()

test_embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=embedding_endpoint_name,
    region_name=aws_region,
    content_handler=test_content_handler,
)

In [None]:
# Embed one sample sentence and show the first five dimensions of its vector.
hello_world_vector = test_embeddings.embed_documents(["Hello World"])[0]
print(hello_world_vector[:5])

## Test LLM endpoint

In [None]:
def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type="application/json"):
    """Invoke SageMaker endpoint *endpoint_name* with an already-encoded body.

    Returns the raw ``invoke_endpoint`` response; its ``Body`` stream is left
    unread for the caller to parse.
    """
    runtime = boto3.client("runtime.sagemaker")
    return runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType=content_type,
        Body=encoded_json,
    )

def parse_response_model(query_response):
    """Return every ``generated_text`` from an LLM ``invoke_endpoint`` response.

    The response ``Body`` is read and decoded as JSON; the result is expected
    to be a list of prediction dicts, returned in order.
    """
    predictions = json.loads(query_response["Body"].read())
    return list(map(lambda prediction: prediction["generated_text"], predictions))


In [None]:
question = "How to determine shard and data node counts for OpenSearch?"

# Request body for the LLM endpoint: the prompt plus generation settings.
payload = dict(
    inputs=question,
    parameters=dict(
        max_new_tokens=1024,
        num_return_sequences=1,
        top_k=100,
        top_p=0.95,
        do_sample=False,
        return_full_text=True,
        temperature=0.9,
    ),
)



In [None]:
# Serialize the payload, call the LLM endpoint, and show the first completion.
encoded_payload = json.dumps(payload).encode("utf-8")
query_response = query_endpoint_with_json_payload(encoded_payload, endpoint_name=llm_endpoint_name)

generated_texts = parse_response_model(query_response)
print(f"The generated output is: {generated_texts[0]}\n")

## Step 3: Load documents with Langchain document loader and store vector into OpenSearch

Use document loaders to load data from a source as Documents. A Document is a piece of text and associated metadata. For example, there are document loaders for loading a simple .txt file, for loading the text contents of any web page, or even for loading a transcript of a YouTube video.

Document loaders expose a "load" method for loading data as documents from a configured source. Here, we use `UnstructuredURLLoader` to load the OpenSearch best-practice web pages.

In [None]:
from langchain.document_loaders import UnstructuredURLLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# OpenSearch Service developer-guide pages used as the knowledge base.
urls = [
    "https://docs.aws.amazon.com/opensearch-service/latest/developerguide/bp.html",
    "https://docs.aws.amazon.com/opensearch-service/latest/developerguide/sizing-domains.html",
    "https://docs.aws.amazon.com/opensearch-service/latest/developerguide/petabyte-scale.html",
    "https://docs.aws.amazon.com/opensearch-service/latest/developerguide/managedomains-dedicatedmasternodes.html",
    "https://docs.aws.amazon.com/opensearch-service/latest/developerguide/cloudwatch-alarms.html",
]

# Splitter settings: chunk_size=1000 with an overlap of 100 between neighbouring chunks.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

# Fetch the pages and split them into chunked Documents.
url_loader = UnstructuredURLLoader(urls=urls)
url_texts = url_loader.load_and_split(text_splitter=text_splitter)


Sample documents

In [None]:
# Chunks that the following cells index into OpenSearch.
all_splits = url_texts
# Preview the first five chunks.
all_splits[:5]

Create an OpenSearch cluster connection.
Next, we'll use the Python API to set up a connection with the Amazon OpenSearch Service domain.

In [None]:
import os

from opensearchpy import OpenSearch, RequestsHttpConnection

# Master-user credentials for the OpenSearch domain. The defaults are the
# lab's values; set AOS_USERNAME / AOS_PASSWORD in the environment to avoid
# relying on a password hard-coded in the notebook.
auth = (
    os.environ.get("AOS_USERNAME", "master"),
    os.environ.get("AOS_PASSWORD", "Semantic123!"),
)

# HTTPS client (port 443, certificate verification on) used to inspect indices.
aos_client = OpenSearch(
    hosts=[{'host': aos_host, 'port': 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
)

Use `OpenSearchVectorSearch` in LangChain to ingest vectors into OpenSearch. You can specify more parameters to create a kNN index with specific properties, such as:
engine: “nmslib”, “faiss”, “lucene”; default: “nmslib”

space_type: “l2”, “l1”, “cosinesimil”, “linf”, “innerproduct”; default: “l2”

ef_search: Size of the dynamic list used during k-NN searches. Higher values lead to more accurate but slower searches; default: 512

ef_construction: Size of the dynamic list used during k-NN graph creation. Higher values lead to more accurate graph but slower indexing speed; default: 512

m: Number of bidirectional links created for each new element. Large impact on memory consumption. Between 2 and 100; default: 16



### Langchain embedding endpoint

To build a simplified QA application with LangChain, we need to wrap our SageMaker endpoints for the embedding model and the LLM in `langchain.embeddings.SagemakerEndpointEmbeddings` and `langchain.llms.sagemaker_endpoint.SagemakerEndpoint`. This requires overriding methods of the `SagemakerEndpointEmbeddings` class to make it compatible with the SageMaker embedding model.

In [None]:
from typing import Any, Dict, Iterable, List, Optional, Tuple, Callable
import json
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.schema import Document

class BulkSagemakerEndpointEmbeddings(SagemakerEndpointEmbeddings):
    """SageMaker endpoint embeddings that send documents in fixed-size batches."""

    def embed_documents(
        self, texts: List[str], chunk_size: Optional[int] = 5
    ) -> List[List[float]]:
        """Compute doc embeddings using a SageMaker Inference Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: How many input texts are grouped into one endpoint
                request. If None, all texts are sent in a single request.

        Returns:
            List of embeddings, one for each text, in input order.

        Raises:
            ValueError: If chunk_size is zero or negative.
        """
        # Nothing to embed; also avoids a zero step in range() below.
        if not texts:
            return []

        batch_size = len(texts) if chunk_size is None else chunk_size
        if batch_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

        results: List[List[float]] = []
        # Slicing past the end is safe, so the final partial batch needs no clamping.
        for start in range(0, len(texts), batch_size):
            results.extend(self._embedding_func(texts[start:start + batch_size]))
        return results
        
class EmbeddingContentHandler(EmbeddingsContentHandler):
    """Request/response (de)serializer for the SageMaker embedding endpoint.

    Request body: ``{"text_inputs": <texts>, **model_kwargs}`` as UTF-8 JSON.
    Response body: JSON whose ``"embedding"`` field holds the vectors.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: List[str], model_kwargs: Optional[Dict] = None) -> bytes:
        """Encode the texts in *prompt* plus any model kwargs as a JSON request body."""
        # None default avoids a shared mutable default dict.
        extra = model_kwargs if model_kwargs is not None else {}
        input_str = json.dumps({"text_inputs": prompt, **extra})
        return input_str.encode('utf-8')

    def transform_output(self, output) -> List[List[float]]:
        """Decode the endpoint response stream and return its ``"embedding"`` list."""
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]


# Embeddings client used for ingestion and retrieval in the rest of the lab.
embeddings = BulkSagemakerEndpointEmbeddings(
    endpoint_name=embedding_endpoint_name,
    region_name=aws_region,
    content_handler=EmbeddingContentHandler(),
)


In [None]:
from langchain.vectorstores import OpenSearchVectorSearch

os_domain_ep = 'https://'+aos_host

embedding_index_name = 'opensearch_kb_vector'

# Index the chunks with at most 500 documents per from_documents call
# (presumably to stay within the vector store's per-call bulk limit).
# Slicing stops at the end of the list on its own, so the final partial batch
# is ingested in full.
ingest_batch_size = 500
for start in range(0, len(all_splits), ingest_batch_size):
    batch = all_splits[start:start + ingest_batch_size]
    OpenSearchVectorSearch.from_documents(
        index_name=embedding_index_name,
        documents=batch,
        embedding=embeddings,
        opensearch_url=os_domain_ep,
        http_auth=auth,
    )
    print(f"ingest documents from {start} to {start + len(batch)}")

In [None]:
# Show the mappings and settings OpenSearch created for the vector index.
index_description = aos_client.indices.get(index=embedding_index_name)
index_description

When you use LangChain `OpenSearchVectorSearch` to store embeddings in an OpenSearch kNN index, you can specify parameters to choose different Approximate Nearest Neighbor (ANN) algorithms. For more information, please refer to the OpenSearch kNN documentation: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/

In [None]:
customized_embedding_index_name = 'customized_opensearch_kb_vector'

# Same chunks, but with explicit kNN settings: FAISS engine, inner-product
# space, and custom graph-construction parameters (ef_construction, m).
# NOTE(review): unlike the default-index cell, this call is not batched.
OpenSearchVectorSearch.from_documents(
    index_name=customized_embedding_index_name,
    documents=all_splits,
    embedding=embeddings,
    opensearch_url=os_domain_ep,
    http_auth=auth,
    engine="faiss",
    space_type="innerproduct",
    ef_construction=256,
    m=48,
)
print(f"ingest documents into customized knn index")

In [None]:
# Show the mappings and settings of the customized kNN index.
index_description = aos_client.indices.get(index=customized_embedding_index_name)
index_description

### OpenSearch vector store

We can use `OpenSearchVectorSearch` as the vector store, or we can extend the class to define a new function that calculates document relevance scores if you want to use the relevance score to filter documents.

In [None]:
class SimiliarOpenSearchVectorSearch(OpenSearchVectorSearch):
    """OpenSearch vector store whose relevance score is the raw search score.

    ``_select_relevance_score_fn`` returns an identity conversion, so any
    relevance-score threshold is compared directly against the score reported
    for each hit.
    """

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        # Use the identity conversion defined on this class.
        return self.relevance_score

    def relevance_score(self, distance: float) -> float:
        return distance


# Vector store over the default kNN index, used by the retriever later on.
open_search_vector_store = SimiliarOpenSearchVectorSearch(
    index_name=embedding_index_name,
    embedding_function=embeddings,
    opensearch_url=os_domain_ep,
    http_auth=auth,
)

Show the documents that are similar to the question "How to determine shard and data node counts for OpenSearch?". By default, 4 documents are returned. You can specify the "k" parameter. See the [doc](https://api.python.langchain.com/en/latest/vectorstores/langchain.vectorstores.opensearch_vector_search.OpenSearchVectorSearch.html#langchain.vectorstores.opensearch_vector_search.OpenSearchVectorSearch.similarity_search_with_score) for more information.


In [None]:
# Top-5 (document, score) pairs for the sample question.
docs_ = open_search_vector_store.similarity_search_with_score(question, k=5)
print("found document number:" + str(len(docs_)))

print("opensearch results:\n")
for doc in docs_:
    print(doc)
    print("\n-----------------")

In [None]:
from langchain.vectorstores import OpenSearchVectorSearch

os_domain_ep = 'https://'+aos_host

embedding_index_name = 'opensearch_best_practice_embedding'

# Index the chunks with at most 500 documents per from_documents call
# (presumably to stay within the vector store's per-call bulk limit).
# Slicing stops at the end of the list on its own, so the final partial batch
# is ingested in full.
ingest_batch_size = 500
for start in range(0, len(all_splits), ingest_batch_size):
    batch = all_splits[start:start + ingest_batch_size]
    OpenSearchVectorSearch.from_documents(
        index_name=embedding_index_name,
        documents=batch,
        embedding=embeddings,
        opensearch_url=os_domain_ep,
        http_auth=auth,
    )
    print(f"ingest documents from {start} to {start + len(batch)}")

In [None]:
# Show the mappings and settings OpenSearch created for the vector index.
index_description = aos_client.indices.get(index=embedding_index_name)
index_description

You will see that the new index has different settings for the vector field. For more information about OpenSearch engines, space types, etc., please refer to the OpenSearch documentation: https://opensearch.org/docs/latest/search-plugins/knn/knn-index/

In [None]:
customized_embedding_index_name = 'customized_opensearch_best_practice_embedding'

# Same chunks, but with explicit kNN settings: FAISS engine, inner-product
# space, and custom graph-construction parameters (ef_construction, m).
OpenSearchVectorSearch.from_documents(
    index_name=customized_embedding_index_name,
    documents=all_splits,
    embedding=embeddings,
    opensearch_url=os_domain_ep,
    http_auth=auth,
    engine="faiss",
    space_type="innerproduct",
    ef_construction=256,
    m=48,
)

In [None]:
# Compare these vector-field settings with the default index above.
index_description = aos_client.indices.get(index=customized_embedding_index_name)
index_description

## Step 4: Test LLM without context information

In [None]:
from uuid import uuid4
from typing import Dict
from langchain.memory import ConversationBufferMemory
from langchain.memory import DynamoDBChatMessageHistory
from langchain.memory import ConversationBufferWindowMemory
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains import RetrievalQA


from typing import Optional


class ContentHandler(LLMContentHandler):
    """Request/response (de)serializer for the LLM endpoint.

    Request body: ``{"inputs": <prompt>, "parameters": <model_kwargs>}`` as
    UTF-8 JSON. Response: a JSON list whose first element carries
    ``generated_text``. Both the request and the generated text are printed.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Optional[Dict] = None) -> bytes:
        """Encode *prompt* and the generation parameters as the JSON request body."""
        # None default avoids a shared mutable default dict.
        parameters = model_kwargs if model_kwargs is not None else {}
        input_str = json.dumps({"inputs": prompt, "parameters": parameters})
        print("Prompt Input:\n" + input_str)
        return input_str.encode("utf-8")

    def transform_output(self, output) -> str:
        """Return the first ``generated_text`` of the response.

        *output* is the response body stream; it is consumed via ``.read()``.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        generated_text = response_json[0]["generated_text"]
        print("LLM generated text:\n" + generated_text)
        return generated_text


content_handler = ContentHandler()


In [None]:
# Generation settings for the "no context" experiment (temperature 0.9).
params = dict(
    max_length=4096,
    max_new_tokens=1024,
    num_return_sequences=1,
    top_k=100,
    top_p=0.95,
    do_sample=False,
    return_full_text=False,
    temperature=0.9,
)

In [None]:
# LLM wrapper using the temperature-0.9 `params`; queried without any
# retrieved context to show unaided answers.
llm_hullucination = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    region_name=aws_region,
    model_kwargs=params,
    content_handler=content_handler,
)

To better illustrate why we need a retrieval-augmented generation (RAG) based approach to solve the question-answering problem, let's directly ask the model a question and see how it responds.

In [None]:
# Ask the model directly, with no retrieved documents in the prompt.
print("Question is:" + question)
llm_hullucination(question)

The generated answer is not fully accurate.

In [None]:
# NOTE(review): this name carries a deployment timestamp, so it appears to be
# specific to one account/deployment; replace it with an endpoint you own.
llm_40b_endpoint_name = 'RAG-LLM-huggingface-llm-falcon-40b-inst-2023-08-29-08-14-50-272'

# Same handler and temperature-0.9 params as above, pointed at the 40B endpoint.
llm_40b_hullucination = SagemakerEndpoint(
    endpoint_name=llm_40b_endpoint_name,
    region_name=aws_region,
    model_kwargs=params,
    content_handler=content_handler,
)


In [None]:
# Ask the falcon-40b endpoint the same question, again without context.
print("Question is:" + question)
llm_40b_hullucination(question)

## Step 5: Retrieval Augmented Generation

 ### Langchain retriever

Here we use the OpenSearch vector store as a retriever to get documents similar to the query. We can also specify a similarity score threshold so that only highly relevant documents are returned. Use "k" to limit how many documents are returned.

In [None]:
# Up to 5 documents, filtered by a minimum relevance score of 0.62.
retriever = open_search_vector_store.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs=dict(k=5, score_threshold=0.62),
)

### LLM without Hallucination

To keep the LLM from generating non-factual answers, we can specify a low "temperature" value when calling the LLM to generate content.

In [None]:
# Same settings as before except a near-zero temperature.
params = dict(
    max_length=4096,
    max_new_tokens=1024,
    num_return_sequences=1,
    top_k=100,
    top_p=0.95,
    do_sample=False,
    return_full_text=False,
    temperature=0.001,
)

# LLM wrapper used by all retrieval chains below.
llm = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    region_name=aws_region,
    model_kwargs=params,
    content_handler=content_handler,
)

In [None]:
# Retrieval QA chain over the OpenSearch retriever.
# Chain type options: stuff, refine, map_reduce, and map_rerank.
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
)

Compare the content generated with RAG and LLM without context.

In [None]:
# Answer the same question, this time grounded in retrieved documents.
print("Question is:" + question)
result = qa({"query": question})
print("result:" + result["result"])
  

### Use customized prompt for RAG

You can also customize the prompt to suit your requirements.

In [None]:
# Prompt that restricts the model to the retrieved context and asks for a
# short answer (at most three sentences).
template = """Use the following "Context:" to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum and keep the answer as concise as possible. 
Context:{context}
Question: {question}
Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate.from_template(template=template)

# Retrieval QA chain using the custom prompt; chain_type is left at its default.
qa_with_prompt = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs=dict(prompt=QA_CHAIN_PROMPT),
)


In [None]:
# Answer the sample question using the custom prompt.
print("Question is:" + question)
result = qa_with_prompt({"query": question})
print("\n### Generated result:" + result["result"])


### Try to ask some questions which are not covered in your knowledge base.

In [None]:
# A question the indexed OpenSearch pages cannot answer.
out_of_scope_question = "Who is Jianwei?"
result = qa_with_prompt({"query": out_of_scope_question})
print("\n### Generated result:" + result["result"])

### Return source documents

You can also return the source documents to help you find the original knowledge base documents.


In [None]:
# "refine" chain that also returns the documents it used.
# Chain type options: stuff, refine, map_reduce, and map_rerank.
qa_with_source = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="refine",
    retriever=retriever,
    return_source_documents=True,
)

In [None]:
print("Question is:" + question)
result = qa_with_source({"query": question})

print("result:" + result["result"])
# Separator followed by each source document that fed the answer.
print("\n\n===========================\n\nsource documents:")
for doc in result["source_documents"]:
    print(doc)
    print("---------------------------\n")



For `RetrievalQA`, you have 4 methods for using retrieved documents as context: stuff, refine, map_reduce, and map_rerank. Please refer to https://python.langchain.com/docs/modules/chains/document/.


## Step 6: Conversational search by memorizing the history 

### Langchain Memory with DynamoDB

Here we create a new session and use DynamoDB as the backend that stores the conversation history.

In [None]:
from langchain.schema import messages_to_dict

# DynamoDB table created by the stack for conversation history.
ddb_table_name = outputs['DynamoDBTableName']
session_id = str(uuid4())
chat_memory = DynamoDBChatMessageHistory(
    table_name=ddb_table_name,
    session_id=session_id,
)

messages = chat_memory.messages

# Maintains immutable sessions: if the session already has history, create a
# new session (new session_id) in the same table and copy the messages into it.
if messages:
    session_id = str(uuid4())
    chat_memory = DynamoDBChatMessageHistory(
        table_name=ddb_table_name,
        session_id=session_id,
    )
    # This is a workaround at the moment. Ideally, this should
    # be added to the DynamoDBChatMessageHistory class.
    try:
        chat_memory.table.put_item(
            Item={"SessionId": session_id, "History": messages_to_dict(messages)}
        )
    except Exception as e:
        # Best-effort copy: the new session is still usable without the history.
        print(f"Could not copy previous history into session {session_id}: {e}")

memory = ConversationBufferMemory(chat_memory=chat_memory, return_messages=True)


In [None]:
# Retrieval QA chain with the DynamoDB-backed conversation memory attached.
# Chain type options: stuff, refine, map_reduce, and map_rerank.
qa_with_memory = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    memory=memory,
)

In [None]:
result = qa_with_memory({"query": question})
print("result:" + result["result"])
print("\nHistory:\n===========================")
# History alternates question and answer messages. Pair them with zip so a
# trailing unanswered message cannot raise IndexError.
history = result["history"]
for past_question, past_answer in zip(history[0::2], history[1::2]):
    print("Question:")
    print(past_question)
    print("Answer:")
    print(past_answer)
    print("---------------------------")

Try asking one more question; the conversation history stored in DynamoDB is also passed to the LLM as context.

In [None]:
second_following_question = 'if my data growth is very fast'
result = qa_with_memory({"query": second_following_question})
print("result:" + result["result"])
print("\nHistory:\n===========================")
# History alternates question and answer messages. Pair them with zip so a
# trailing unanswered message cannot raise IndexError.
history = result["history"]
for past_question, past_answer in zip(history[0::2], history[1::2]):
    print("Question:")
    print(past_question)
    print("Answer:")
    print(past_answer)
    print("---------------------------")

# Deploy Model

## Deploy embedding model

In [None]:
# JumpStart id of the text-embedding model, and its version
# ("*" presumably selects the latest available version).
embedding_model_id = "huggingface-textembedding-gpt-j-6b-fp16"
embedding_model_version = "*"

In [None]:
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


In [None]:
# Deploy the JumpStart embedding model to a new SageMaker endpoint.
# Note: this reassigns `embedding_endpoint_name` (previously read from the
# stack outputs) to a freshly generated, unique name.
embedding_endpoint_name = name_from_base(f"RAG-embedding-{embedding_model_id}")

embedding_instance_type = "ml.g5.2xlarge"

# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.

embedding_deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=embedding_model_id,
    model_version=embedding_model_version,
    instance_type=embedding_instance_type,
)

# Retrieve the model uri.
embedding_model_uri = model_uris.retrieve(
    model_id=embedding_model_id, model_version=embedding_model_version, model_scope="inference"
)


# Combine the container image and model artifacts into a SageMaker Model that
# runs under the notebook's IAM role; the model is named after the endpoint.
embedding_model = Model(
    image_uri=embedding_deploy_image_uri,
    model_data=embedding_model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=embedding_endpoint_name,
)

# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
# wait=False: do not block on endpoint creation; poll the endpoint status separately.
embedding_model_predictor = embedding_model.deploy(
    initial_instance_count=1,
    instance_type=embedding_instance_type,
    predictor_cls=Predictor,
    endpoint_name=embedding_endpoint_name,
    wait=False
)

## Deploy content generation model

In [None]:
# Dropdown was used without being imported (NameError); bring the widget and
# display helper into scope explicitly.
from ipywidgets import Dropdown
from IPython.display import display

# Default JumpStart LLM and version ("*" presumably means latest).
llm_model_id, llm_model_version = (
    "huggingface-llm-falcon-7b-instruct-bf16",
    "*",
)

# Falcon JumpStart models offered in the selector.
llm_model_ids = [
    'huggingface-llm-falcon-40b-bf16',
    'huggingface-llm-falcon-40b-instruct-bf16',
    'huggingface-llm-falcon-7b-bf16',
    'huggingface-llm-falcon-7b-instruct-bf16',
]

# display the model-ids in a dropdown to select a model for inference.
model_dropdown = Dropdown(
    options=llm_model_ids,
    value=llm_model_id,
    description="Select a model",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
display(model_dropdown)

In [None]:
# Use whichever model is currently selected in the dropdown.
llm_model_id = model_dropdown.value
print(llm_model_id)

In [None]:
# Unique endpoint name derived from the selected model id.
llm_endpoint_name = name_from_base(f"RAG-LLM-{llm_model_id}")

# Hosting instance type, and the value passed as
# container_startup_health_check_timeout at deploy time.
llm_inference_instance_type = "ml.g5.48xlarge"
health_check_timeout = 1800



In [None]:
from sagemaker.jumpstart.model import JumpStartModel

# JumpStart model for the selected LLM; container limits are set through env.
llm_model = JumpStartModel(model_id=llm_model_id, instance_type=llm_inference_instance_type)
llm_model.env.update(
    {
        'SM_NUM_GPUS': '8',
        'MAX_INPUT_LENGTH': '2048',
        'MAX_TOTAL_TOKENS': '4096',
    }
)

# Start the deployment without blocking on endpoint creation.
llm_model_predictor = llm_model.deploy(
    endpoint_name=llm_endpoint_name,
    container_startup_health_check_timeout=health_check_timeout,
    wait=False,
)

## deploy with more code

In [None]:

llm_endpoint_name = name_from_base(f"RAG-LLM-{llm_model_id}")

llm_deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=llm_model_id,
    model_version=llm_model_version,
    instance_type=llm_inference_instance_type,
)

# Retrieve the model uri.
llm_model_uri = model_uris.retrieve(
    model_id=llm_model_id, model_version=llm_model_version, model_scope="inference"
)

number_of_gpu = 8
max_input_length = 2048
max_total_tokens = 4096

# Hugging Face Hub ids of the selectable Falcon JumpStart models. HF_MODEL_ID
# must follow the selected llm_model_id instead of always naming falcon-40b-instruct.
# NOTE(review): presumably the container loads the model named by HF_MODEL_ID;
# confirm this mapping against the JumpStart model specs.
jumpstart_to_hf_model_id = {
    'huggingface-llm-falcon-40b-bf16': "tiiuae/falcon-40b",
    'huggingface-llm-falcon-40b-instruct-bf16': "tiiuae/falcon-40b-instruct",
    'huggingface-llm-falcon-7b-bf16': "tiiuae/falcon-7b",
    'huggingface-llm-falcon-7b-instruct-bf16': "tiiuae/falcon-7b-instruct",
}
if llm_model_id not in jumpstart_to_hf_model_id:
    raise KeyError(
        f"No Hugging Face model id known for {llm_model_id!r}; "
        f"supported: {sorted(jumpstart_to_hf_model_id)}"
    )

# Container environment: values are JSON-encoded, so numbers become strings.
model_env = {
    'HF_MODEL_ID': jumpstart_to_hf_model_id[llm_model_id],
    'SM_NUM_GPUS': json.dumps(number_of_gpu),
    'MAX_INPUT_LENGTH': json.dumps(max_input_length),
    'MAX_TOTAL_TOKENS': json.dumps(max_total_tokens),
}


llm_model = Model(
    image_uri=llm_deploy_image_uri,
    model_data=llm_model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=llm_endpoint_name,
    env=model_env
)

# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
llm_model_predictor = llm_model.deploy(
    initial_instance_count=1,
    instance_type=llm_inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=llm_endpoint_name,
    container_startup_health_check_timeout=health_check_timeout,
    wait=False
)