In [None]:
%pip install sagemaker gradio langchain==0.0.245 pypdf weaviate-client --force-reinstall --quiet 
%pip install dependencies/botocore-1.29.162-py3-none-any.whl dependencies/boto3-1.26.162-py3-none-any.whl dependencies/awscli-1.27.162-py3-none-any.whl --force-reinstall --quiet

In [None]:
import langchain

# The rest of the notebook relies on langchain 0.0.245 APIs; fail fast if the kernel is still
# running another version (e.g. it was not restarted after the install cell).
assert langchain.__version__ == "0.0.245", f"expected langchain 0.0.245, found {langchain.__version__}"
langchain.__version__

In [None]:
# python libraries
import ast
import boto3
from datetime import datetime
import gradio as gr
import json
import os
from sagemaker.huggingface import get_huggingface_llm_image_uri
from sagemaker.huggingface import HuggingFaceModel
from typing import List
import weaviate
import sagemaker

# langchain libraries
from langchain import PromptTemplate
#from langchain.chains import ConversationalRetrievalChain
from langchain_utils.base import ConversationalRetrievalChain
from langchain.llms.sagemaker_endpoint import  SagemakerEndpoint, LLMContentHandler
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever

from langsmith import Client

# Service clients: SageMaker (used for endpoint cleanup at the end) and LangSmith.
# NOTE(review): the LangSmith client is constructed before LANGCHAIN_API_KEY is set in a later
# cell, so it presumably does not pick up that key; consider creating it after that cell.
sm_client = boto3.client(service_name='sagemaker')
ls_client = Client()

In [None]:
# SageMaker session plus the IAM role the deployed model will run under.
sagemaker_session = sagemaker.Session()
role = sagemaker_session.get_caller_identity_arn()
# Region of the default boto3 session; the LangChain endpoint wrapper is pointed at it.
region = boto3.Session().region_name

<mark>Set the DNS name of the load balancer in front of the Weaviate instance</mark>

In [None]:
# Host name of the load balancer in front of Weaviate (it is prefixed with "http://" below).
# Taken from the WEAVIATE_ELB_ENDPOINT environment variable when set; otherwise replace "" here.
elb_endpoint = os.environ.get("WEAVIATE_ELB_ENDPOINT", "")

In [None]:
# Fail fast with a clear message rather than attempting to connect to a bare "http://".
if not elb_endpoint:
    raise ValueError("Set `elb_endpoint` to the DNS name of the Weaviate load balancer before connecting.")
# Accept either a bare host name (the documented form) or a full URL.
weaviate_url = elb_endpoint if elb_endpoint.startswith(("http://", "https://")) else f"http://{elb_endpoint}"
wv_client = weaviate.Client(url=weaviate_url)

<mark>Optional but recommended: provide your LangSmith API key to enable tracing</mark>

In [None]:
# LangSmith API key, read from the LANGCHAIN_API_KEY environment variable so no secret is
# hardcoded in the notebook; empty string when the variable is not set.
langsmith_api_key = os.environ.get("LANGCHAIN_API_KEY", "")

In [None]:
# LangSmith tracing configuration; traced runs are grouped into one project per day.
today = datetime.now().strftime("%Y%m%d")
os.environ["LANGCHAIN_PROJECT"] = f"QA Chain - {today}"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
if langsmith_api_key:
    os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
else:
    # The key is optional: without one, traced runs could not authenticate against LangSmith,
    # so leave tracing disabled instead of producing upload errors on every chain call.
    os.environ.pop("LANGCHAIN_TRACING_V2", None)

<h1>Deploy SageMaker Endpoint</h1>

In [None]:
# Model, hardware and startup settings for the SageMaker endpoint.
model_id = 'OpenAssistant/pythia-12b-sft-v8-7k-steps'
instance_type = 'ml.g5.48xlarge'
num_gpus = 8  # GPUs used per model replica on the instance
health_check_timeout = 10 * 60  # seconds the container may take to load the model

# Resolve the Hugging Face LLM inference container image (version 0.8.2) and show its ECR URI.
llm_image = get_huggingface_llm_image_uri("huggingface", version="0.8.2")
print(f"llm image uri: {llm_image}")

In [None]:
# Container settings, passed to the endpoint as string-valued environment variables.
# NOTE(review): MAX_INPUT_LENGTH / MAX_TOTAL_TOKENS (20024 / 20048) look far larger than the
# context window a Pythia-based model was presumably trained with (2048 tokens) and may be a
# typo for 2024 / 2048; confirm before relying on long prompts.
config = {
  'HF_MODEL_ID': model_id,
  'SM_NUM_GPUS': str(num_gpus),           # number of GPUs used per replica
  'MAX_INPUT_LENGTH': str(20024),         # max length of input text
  'MAX_TOTAL_TOKENS': str(20048),         # max length of input + generated text
  # 'HF_MODEL_QUANTIZE': "bitsandbytes",  # uncomment to quantize
  # Active on purpose: one request at a time to limit OOM errors, see
  # https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text-generation-falcon.ipynb
  'MAX_CONCURRENT_REQUESTS': str(1),
}

# SageMaker model definition wrapping the LLM container image.
llm_oa_model = HuggingFaceModel(
  image_uri=llm_image,
  env=config,
  role=role,
)

# Create the real-time endpoint; the generous health-check timeout leaves time to load the model.
llm_oa_endpoint = llm_oa_model.deploy(
  instance_type=instance_type,
  initial_instance_count=1,
  container_startup_health_check_timeout=health_check_timeout,
  # volume_size=400,  # If using an instance with local SSD storage, volume_size must be None, e.g. p4 but not p3
)
llm_oa_endpoint_name = llm_oa_endpoint.endpoint_name

<h1>Define LangChain LLM</h1>

In [None]:
# OpenAssistant LLM: default generation parameters sent with every request.
oa_parameters = dict(
    do_sample=True,
    top_p=0.7,
    temperature=0.1,
    top_k=50,
    return_full_text=False,   # return only the completion, not the prompt
    max_new_tokens=500,
    repetition_penalty=1.03,
    stop=["<|endoftext|>"],   # end-of-turn token of the OpenAssistant prompt format
)

class OAContentHandler(LLMContentHandler):
    """Serialize prompts for, and parse completions from, the text-generation endpoint.

    Requests are JSON ``{"inputs": <prompt>, "parameters": {...}}``; responses are a JSON
    list whose first element carries ``generated_text``.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs=None) -> bytes:
        """Encode ``prompt`` as a UTF-8 JSON request body.

        Args:
            prompt: the fully formatted prompt text.
            model_kwargs: generation parameters for this call (presumably the wrapper's
                ``model_kwargs``); they override the module-level ``oa_parameters`` defaults.

        Returns:
            The request body as bytes.
        """
        # All generation settings go inside "parameters". The previous version spread the
        # kwargs at the top level of the payload, where the server presumably ignores them,
        # so per-call overrides never took effect.
        parameters = {**oa_parameters, **(model_kwargs or {})}
        return json.dumps({"inputs": prompt, "parameters": parameters}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the generated text from a response body.

        Args:
            output: the response body. It is consumed via ``.read()``, so in practice it is a
                file-like stream rather than raw bytes.

        Returns:
            The text between the first and second ``<|assistant|>`` markers, or the whole
            generated text when no marker is present.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        generated = response_json[0]["generated_text"]
        # A missing marker is the only expected case; check for it explicitly instead of a
        # bare except, which also hid unrelated errors.
        segments = generated.split('<|assistant|>')
        return segments[1] if len(segments) > 1 else generated


# LangChain LLM bound to the deployed endpoint; oa_parameters are its default generation kwargs.
llm_oa = SagemakerEndpoint(
    content_handler=OAContentHandler(),
    endpoint_name=llm_oa_endpoint_name,
    model_kwargs=oa_parameters,
    region_name=region,
)

In [None]:
# Smoke test: one question wrapped in the OpenAssistant prompter/assistant format.
smoke_test_prompt = "<|prompter|>What day comes after Tuesday?<|endoftext|><|assistant|>"
llm_oa(smoke_test_prompt)

<h1>Define Weaviate Retriever</h1>

In [None]:
# Hybrid-search retriever over the "ManualContent" class. Chunk text comes from "content";
# "model_names" and "file" are returned as document metadata (used by the app's source list).
wv_hybrid_retriever = WeaviateHybridSearchRetriever(
    index_name="ManualContent",
    text_key="content",
    attributes=["model_names", "file"],
    create_schema_if_missing=True,
    client=wv_client,
)

In [None]:
# Hybrid search across all manuals, with no metadata filter.
wv_hybrid_retriever.get_relevant_documents("how do I unlock the screen?")

In [None]:
# Same query, restricted by a Weaviate where-filter on the model_names property.
galaxy_s22_filter = {
    "path": ["model_names"],
    "operator": "Equal",
    "valueText": "Galaxy S22",
}
wv_hybrid_retriever.get_relevant_documents(
    query="how do I unlock the screen?",
    where_filter=galaxy_s22_filter,
)

<h1>Define QA Chain and Prompt</h1>

In [None]:
# Set Context for response
# QA prompt in the OpenAssistant chat format: instructions, retrieved {context} and the user's
# {question} form the <|prompter|> turn, closed by <|endoftext|>; the model answers after
# <|assistant|>. The string is kept verbatim because any whitespace change alters the prompt.
OA_TEMPLATE = """<|prompter|>You are a tech professional who is an expert in smartphones. The context provided is a portion from a smartphone's user manual.

Use the following context to answer the user's question. Make sure to read all the context before providing an answer.  
Only provide answers that are drawn directly from the context provided.  

If you do not find reference to the question in the provided context, say 'Sorry, I do not find any reference to this question in the provided context'

\nContext:\n{context}\nQuestion: {question}<|endoftext|><|assistant|>
"""


# Prompt object that replaces the QA chain's default combine-documents prompt further below.
QA_PROMPT = PromptTemplate(template=OA_TEMPLATE, input_variables=["question", "context"])

In [None]:
# Conversational retrieval chain over the hybrid retriever; source documents are returned with
# each answer so the app can list the files and models behind it.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm_oa,
    retriever=wv_hybrid_retriever,
    return_source_documents=True,
    verbose=True,
)

In [None]:
# Show the chain's default QA prompt before it is replaced.
default_qa_prompt = qa_chain.combine_docs_chain.llm_chain.prompt
print(default_qa_prompt.template)

In [None]:
# Swap in the OpenAssistant-formatted prompt. Assigning the attribute skips the prompt-variable
# validation the documents chain performs at construction, so check it explicitly here rather
# than discovering a mismatch on the first query.
_doc_variable = getattr(qa_chain.combine_docs_chain, "document_variable_name", "context")
assert _doc_variable in QA_PROMPT.input_variables, (
    f"QA_PROMPT must expose the '{_doc_variable}' variable that the documents chain fills in"
)
qa_chain.combine_docs_chain.llm_chain.prompt = QA_PROMPT

In [None]:
# Confirm the OpenAssistant prompt is now the one the chain uses.
active_qa_prompt = qa_chain.combine_docs_chain.llm_chain.prompt
print(active_qa_prompt.template)

<h1>Create Gradio App</h1>

In [None]:
# Find the unique phone model names stored in Weaviate to populate the app's dropdown.
# NOTE(review): this queries the "Manual" class while the retriever searches "ManualContent";
# confirm both names refer to the intended data.
# NOTE(review): Get queries presumably return only a limited page of objects by default, so
# models that appear only beyond that page would be missing; consider `.with_limit(...)`.
response = (
    wv_client.query
    .get("Manual", ['model_names'])
    .do()
)
# GraphQL errors (e.g. an unknown class) come back in the response body rather than raising.
if response.get('errors'):
    raise RuntimeError(f"Weaviate query failed: {response['errors']}")

# Collect unique model names across objects; objects without model_names are skipped, and the
# result is sorted so the dropdown order is stable between runs.
model_names = sorted({
    name
    for obj in response['data']['Get']['Manual']
    for name in (obj.get('model_names') or [])
})
model_names

In [None]:
with gr.Blocks() as demo:
    gr.Markdown("## Samsung smartphone support service")

    with gr.Column():
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column():
                message = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box", show_label=False)
                model = gr.Dropdown(model_names, multiselect=False, label="Model")
            with gr.Column():
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.Button("Clear")
    # Hidden until the first answer arrives; then lists the sources behind the answer.
    with gr.Column(visible=False) as resource_col:
        resource_box = gr.Textbox(label="Resources", interactive=False)


    def respond(message, model, chat_history):
        """Answer one chat turn and report the source files and models used.

        Args:
            message: the user's new chat message.
            model: the value selected in the Model dropdown, forwarded to the chain.
            chat_history: the Chatbot's current value, a sequence of (user, bot) pairs
                (None after Clear).

        Returns:
            A 4-tuple matching ``outputs``: "" to clear the textbox, the updated history,
            an update showing the resources column, and an update setting the resources text.
        """
        # Previous turns go to the chain as (question, answer) tuples. The old code iterated
        # the empty local list instead of chat_history, so earlier turns were always dropped.
        history = [(turn[0], turn[1]) for turn in (chat_history or [])]

        # Send the request to the endpoint-backed chain; the question payload also carries the model.
        result = qa_chain({"question": {"question": message, "model": model}, "chat_history": history})
        answer = result['answer']

        # Summarize sources: each file once in first-seen order, plus the union of model names.
        files = []
        source_models = set()
        for doc in result['source_documents']:
            metadata = doc.metadata
            source_models.update(metadata.get('model_names') or [])
            file = metadata.get('file')
            if file and file not in files:
                files.append(file)
        sources_str = ''.join(f'FILE: {file}\n' for file in files)
        sources_str += 'MODELS: ' + ','.join(sorted(source_models))

        history.append((message, answer))

        # gr.update works on both Gradio 3 and 4, unlike the removed Component.update().
        return "", history, gr.update(visible=True), gr.update(value=sources_str)

    submit.click(fn=respond, inputs=[message, model, chatbot], outputs=[message, chatbot, resource_col, resource_box], queue=False)
    clear.click(lambda: None, None, chatbot, queue=False)
demo.launch(share=True)

<h2>Cleanup</h2>

In [None]:
# Tear down everything the deploy step created: the endpoint, its endpoint config and the
# model(s) it serves, so no resources are left behind in the account.
# Look up the config and model names first, while the endpoint that references them still exists.
endpoint_config_name = sm_client.describe_endpoint(EndpointName=llm_oa_endpoint_name)["EndpointConfigName"]
endpoint_model_names = [
    variant["ModelName"]
    for variant in sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)["ProductionVariants"]
]

sm_client.delete_endpoint(EndpointName=llm_oa_endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
for endpoint_model_name in endpoint_model_names:
    sm_client.delete_model(ModelName=endpoint_model_name)