In [None]:
!pip install langchain

In [None]:
import json
import os
import uuid
import time

from langchain.chains import ConversationalRetrievalChain
from langchain import SagemakerEndpoint
from langchain.prompts.prompt import PromptTemplate
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.llms.sagemaker_endpoint import ContentHandlerBase, LLMContentHandler
from langchain.memory import ConversationBufferWindowMemory
from langchain import PromptTemplate, LLMChain
from langchain.memory.chat_message_histories import DynamoDBChatMessageHistory
from langchain.retrievers import AmazonKendraRetriever
from langchain.callbacks.base import BaseCallbackHandler


In [None]:
# Deployment settings. Each value can be overridden with an environment
# variable of the same name, so the notebook can target another account,
# index or endpoint without editing this cell. The literals are the fallbacks.
REGION = os.environ.get("REGION", "us-east-1")
KENDRA_INDEX_ID = os.environ.get(
    "KENDRA_INDEX_ID", "8b2f0a8c-0bdd-4322-a9ce-e971f65de06e"
)
# Alternative endpoint names previously listed for this notebook:
#   jumpstart-dft-hf-llm-falcon-40b-instruct-bf16-g548xlarge
#   jumpstart-dft-hf-llm-falcon-40b-instruct-bf16-p4d24xlarge
SM_ENDPOINT_NAME = os.environ.get(
    "SM_ENDPOINT_NAME",
    "jumpstart-dft-hf-llm-falcon-40b-instruct-bf16-g512xlarge",
)



In [None]:
# Content handler for Option 2 (Falcon-40B-Instruct): converts between
# LangChain prompt/completion strings and the endpoint's JSON payloads.
class ContentHandler(LLMContentHandler):
    """Serialize prompts for, and parse responses from, the Falcon endpoint."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt, model_kwargs):
        """Build the UTF-8 encoded JSON request body for one prompt.

        Args:
            prompt: Text to send as the ``inputs`` field.
            model_kwargs: Optional mapping of generation parameters. Its
                entries override the defaults below; ``None`` or an empty
                mapping yields exactly the default payload.

        Returns:
            The JSON request encoded as UTF-8 bytes.
        """
        parameters = {
            "do_sample": False,
            "repetition_penalty": 1.1,
            "return_full_text": False,
            "max_new_tokens": 200,
        }
        # Previously model_kwargs was ignored, so parameters configured on the
        # LLM never reached the endpoint. Caller-supplied values now win.
        parameters.update(model_kwargs or {})
        input_str = json.dumps({"inputs": prompt, "parameters": parameters})
        return input_str.encode("utf-8")

    def transform_output(self, output):
        """Extract the generated text from the endpoint response.

        Args:
            output: Object exposing ``read()`` that returns the UTF-8 encoded
                JSON response body.

        Returns:
            The ``generated_text`` field of the first element of the decoded
            JSON list.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generated_text"]


content_handler = ContentHandler()

In [None]:
# Callback handler for measuring LLM inference latency
class LatencyHandler(BaseCallbackHandler):
    """Callback handler that prints how long each LLM call takes.

    After a call completes, ``start_time`` and ``end_time`` hold wall-clock
    timestamps (``time.time()``) and ``time_take_by_llm_to_generate_text``
    holds the elapsed seconds measured with a monotonic clock.
    """

    def on_llm_start(self, serialized, prompts, **kwargs):
        """Record when the LLM call starts."""
        self.start_time = time.time()
        # Durations are taken from perf_counter(): unlike time.time() it is
        # monotonic, so system clock adjustments cannot skew the measurement.
        self._start_counter = time.perf_counter()

    def on_llm_end(self, response, **kwargs):
        """Compute, store and print the latency of the finished LLM call."""
        elapsed = time.perf_counter() - self._start_counter
        self.end_time = time.time()
        self.time_take_by_llm_to_generate_text = elapsed
        print(f'Inference latency: {self.time_take_by_llm_to_generate_text}')


lh = LatencyHandler()

In [None]:
# LLM wrapper around the SageMaker endpoint. Requests and responses go through
# `content_handler`, and `lh` is registered as a callback on every call.
# No model_kwargs are passed here.
llm = SagemakerEndpoint(
    content_handler=content_handler,
    callbacks=[lh],
    endpoint_name=SM_ENDPOINT_NAME,
    region_name=REGION,
)

In [None]:
# Prompt that rewrites a follow-up question into a standalone one, using the
# chat history. Placeholders: {chat_history} and {question}.
# The first line intentionally keeps its trailing space before the newline.
_template = (
    "Given the following conversation and a follow up question, rephrase the "
    "follow up question to be a standalone question, in its original language. \n"
    "\n"
    "Chat History:\n"
    "{chat_history}\n"
    "Follow Up Input: {question}\n"
    "Standalone question:"
)
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

In [None]:
# A fresh random session id is generated each time this cell runs; the chat
# history is stored under it in the "MemoryTable" DynamoDB table.
session_id = str(uuid.uuid4().int)
message_history = DynamoDBChatMessageHistory(
    table_name="MemoryTable",
    session_id=session_id,
)

# Windowed conversation memory (k=3) backed by the DynamoDB history above and
# exposed to the chain under the "chat_history" key.
memory = ConversationBufferWindowMemory(
    k=3,
    memory_key="chat_history",
    return_messages=True,
    chat_memory=message_history,
)

In [None]:
# Retriever over the configured Amazon Kendra index, requesting top_k=2 results.
retriever = AmazonKendraRetriever(
    top_k=2,
    region_name=REGION,
    index_id=KENDRA_INDEX_ID,
)

In [None]:
# Conversational retrieval chain wiring together the LLM, the Kendra
# retriever, the conversation memory and the question-condensing prompt.
qa = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    condense_question_prompt=CONDENSE_QUESTION_PROMPT,
    verbose=True,
)

In [None]:
# Opening question of the conversation; the returned answer is the cell output.
first_question = "What is Amazon EC2?"
qa.run(first_question)

In [None]:
# Follow-up that only makes sense together with the stored chat history.
pricing_follow_up = "Cool, how is the pricing?"
qa.run(pricing_follow_up)

In [None]:
# Another context-dependent follow-up ("one" refers back to earlier turns).
provision_follow_up = "How can I provision one?"
qa.run(provision_follow_up)