## Integrating LangChain + Amazon SageMaker

### SageMaker JumpStart Flan-T5 Deployment

Original code retrieved from: https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text2text-generation-flan-t5.ipynb

In [None]:
# Notebook setup: install ipywidgets pinned to 7.0.0 and upgrade the SageMaker
# Python SDK. `--quiet` keeps pip's output short.
!pip install ipywidgets==7.0.0 --quiet
!pip install --upgrade sagemaker --quiet

In [None]:
# JumpStart identifier of the Flan-T5 XL text2text model to deploy.
model_id = "huggingface-text2text-flan-t5-xl"
# Wildcard version spec; presumably resolves to the newest available version.
model_version = "*"

In [None]:
from sagemaker.jumpstart.model import JumpStartModel


# Build a JumpStart model handle from the model id and version selected earlier.
model = JumpStartModel(model_id=model_id, model_version=model_version)

# Deploy the model. The returned predictor is the object query_endpoint() calls
# .predict() on.
# NOTE(review): deploy() presumably provisions a billable real-time endpoint
# using JumpStart's default instance type — confirm, and delete it when done.
model_predictor = model.deploy()

In [None]:
# Line break plus the ANSI escape codes that switch bold text on and off.
newline = "\n"
bold = "\033[1m"
unbold = "\033[0m"


def query_endpoint(encoded_text):
    """Send ``encoded_text`` to the deployed endpoint and return its reply.

    The payload is forwarded unchanged to ``model_predictor.predict`` and
    whatever that call returns is handed back untouched.
    """
    return model_predictor.predict(encoded_text)


def parse_response(query_response):
    """Return the ``"generated_text"`` entry of an endpoint response.

    ``query_response`` must support lookup by that key; a mapping without it
    raises ``KeyError``.
    """
    return query_response["generated_text"]

### Sample Inference

In [None]:
# Line break plus the ANSI escape codes that switch bold text on and off.
newline = "\n"
bold = "\033[1m"
unbold = "\033[0m"

text1 = "Translate to German:  My name is Arthur"
text2 = "A step by step recipe to make bolognese pasta:"

# Send each prompt to the endpoint as UTF-8 bytes, then show the prompt next to
# the (bold) completion.
for text in (text1, text2):
    query_response = query_endpoint(text.encode("utf-8"))
    generated_text = parse_response(query_response)
    # sep=newline puts a line break after every field; the empty last field
    # preserves the blank line that separates consecutive results.
    print(
        "Inference:",
        f"input text: {text}",
        f"generated text: {bold}{generated_text}{unbold}",
        "",
        sep=newline,
    )

In [None]:
# JSON request body for calling the endpoint directly: the prompt goes in
# "text_inputs"; the remaining keys are generation settings.
# NOTE(review): key names presumably follow the JumpStart Flan-T5 inference
# schema — confirm against the model's documentation.
payload = {
    "text_inputs": "Tell me the steps to make a pizza",
    "max_length": 50,
    "num_return_sequences": 3,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
}

In [None]:
import json

# boto3 was used below without ever being imported in this notebook.
import boto3

# Target the endpoint deployed earlier in this notebook. Without this line the
# cell reads `endpoint_name` before any cell has assigned it.
endpoint_name = model_predictor.endpoint_name

# Call the endpoint through the low-level SageMaker runtime client instead of
# the SDK predictor, sending the payload as UTF-8 encoded JSON.
client = boto3.client("runtime.sagemaker")
encoded_payload = json.dumps(payload).encode("utf-8")  # JSON serialization
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=encoded_payload,
)

# The response body is a stream; read it fully and decode the JSON document.
model_predictions = json.loads(response["Body"].read())
# Display the first of the returned generations.
model_predictions["generated_texts"][0]

### LangChain Integration

LangChain/SageMaker Documentation: https://python.langchain.com/docs/integrations/llms/sagemaker

In [None]:
from langchain import LLMChain
from langchain import SagemakerEndpoint
from langchain.prompts import PromptTemplate
from langchain.llms.sagemaker_endpoint import LLMContentHandler

In [None]:
sample_prompt = "Tell me the steps to make a pizza"

# Generation settings passed to the LangChain endpoint wrapper as model_kwargs.
model_params = {
    "max_length": 100,
    "num_return_sequences": 1,
    "top_k": 100,
    "top_p": 0.95,
    "do_sample": True,
}

# NOTE(review): this looks like the name of one specific, previously deployed
# endpoint (timestamp suffix) — replace it with the name of your own endpoint.
endpoint_name = "hf-text2text-flan-t5-xl-2023-09-21-18-45-14-357"

In [None]:
# The template consists of nothing but the {question} placeholder, so the text
# given to the chain becomes the prompt as-is.
prompt_template = "{question}"

prompt = PromptTemplate(input_variables=["question"], template=prompt_template)

In [None]:
"""
        How our input payload needs to look:
        payload = {
        "text_inputs": "Tell me the steps to make a pizza",
        "max_length": 100,
        "num_return_sequences": 1,
        "top_k": 100,
        "top_p": .95,
        "do_sample": True,
        }
"""


# Imported here so this cell also works when the earlier cells that import
# json have not been run.
import json


class ContentHandler(LLMContentHandler):
    """Convert between LangChain prompts and the endpoint's JSON format.

    Requests are JSON objects holding the prompt under ``"text_inputs"`` plus
    the generation settings; responses are read as JSON objects carrying a
    ``"generated_texts"`` list.
    """

    # MIME types declared for the request body and the accepted response.
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        """Serialize the prompt and generation settings to UTF-8 JSON bytes.

        Args:
            prompt: Text to send as ``"text_inputs"``.
            model_kwargs: Extra request fields. They are merged after the
                prompt, so a ``"text_inputs"`` entry here would replace it.

        Returns:
            The encoded JSON request body.
        """
        request_body = {"text_inputs": prompt, **model_kwargs}
        return json.dumps(request_body).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        """Extract the first generated text from the endpoint response.

        Args:
            output: Response body. It is consumed through ``.read()``, i.e. it
                is a readable stream of bytes rather than a ``str``.
                NOTE(review): annotated ``bytes`` to follow LangChain's
                documented handler signature — confirm against the installed
                version.

        Returns:
            The first element of the response's ``"generated_texts"`` list.
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["generated_texts"][0]


content_handler = ContentHandler()

In [None]:
# LangChain LLM backed by the SageMaker endpoint: model_params travel along as
# model_kwargs and content_handler does the request/response conversion.
# NOTE(review): the region is hard-coded; presumably it has to be the region
# the endpoint was deployed in — confirm.
llm = SagemakerEndpoint(
    endpoint_name=endpoint_name,
    region_name="us-east-1",
    model_kwargs=model_params,
    content_handler=content_handler,
)

In [None]:
# Tie the prompt template to the endpoint-backed LLM.
chain = LLMChain(llm=llm, prompt=prompt)

In [None]:
# Run the chain once with sample_prompt as its input; the prompt template's
# only variable is {question}.
chain.run(sample_prompt)